MLIR 22.0.0git
AMDGPUToROCDL.cpp
1//===- AMDGPUToROCDL.cpp - AMDGPU to ROCDL dialect conversion -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
20#include "mlir/IR/Attributes.h"
23#include "mlir/IR/Matchers.h"
25#include "mlir/Pass/Pass.h"
26
28
29#include "llvm/ADT/STLExtras.h"
30#include "llvm/ADT/TypeSwitch.h"
31#include "llvm/Support/Casting.h"
32#include "llvm/Support/ErrorHandling.h"
33#include <optional>
34
35namespace mlir {
36#define GEN_PASS_DEF_CONVERTAMDGPUTOROCDLPASS
37#include "mlir/Conversion/Passes.h.inc"
38} // namespace mlir
39
40using namespace mlir;
41using namespace mlir::amdgpu;
42
43// Define commonly used chipset versions for convenience.
44constexpr Chipset kGfx908 = Chipset(9, 0, 8);
45constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
46constexpr Chipset kGfx942 = Chipset(9, 4, 2);
47constexpr Chipset kGfx950 = Chipset(9, 5, 0);
48constexpr Chipset kGfx1250 = Chipset(12, 5, 0);
49
50/// Convert an unsigned number `val` to i32.
51static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
52 Location loc, Value val) {
53 IntegerType i32 = rewriter.getI32Type();
54 // Force check that `val` is of int type.
55 auto valTy = cast<IntegerType>(val.getType());
56 if (i32 == valTy)
57 return val;
58 return valTy.getWidth() > 32
59 ? Value(LLVM::TruncOp::create(rewriter, loc, i32, val))
60 : Value(LLVM::ZExtOp::create(rewriter, loc, i32, val));
61}
62
63static Value createI32Constant(ConversionPatternRewriter &rewriter,
64 Location loc, int32_t value) {
65 return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), value);
66}
67
68/// Convert an unsigned number `val` to i64.
69static Value convertUnsignedToI64(ConversionPatternRewriter &rewriter,
70 Location loc, Value val) {
71 IntegerType i64 = rewriter.getI64Type();
72 // Force check that `val` is of int type.
73 auto valTy = cast<IntegerType>(val.getType());
74 if (i64 == valTy)
75 return val;
76 return valTy.getWidth() > 64
77 ? Value(LLVM::TruncOp::create(rewriter, loc, i64, val))
78 : Value(LLVM::ZExtOp::create(rewriter, loc, i64, val));
79}
80
81static Value createI64Constant(ConversionPatternRewriter &rewriter,
82 Location loc, int64_t value) {
83 return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), value);
84}
85
86/// Returns the linear index used to access an element in the memref.
87static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
88 Location loc, MemRefDescriptor &memRefDescriptor,
89 ValueRange indices, ArrayRef<int64_t> strides) {
90 IntegerType i32 = rewriter.getI32Type();
91 Value index;
92 for (auto [i, increment, stride] : llvm::enumerate(indices, strides)) {
93 if (stride != 1) { // Skip if stride is 1.
94 Value strideValue =
95 ShapedType::isDynamic(stride)
96 ? convertUnsignedToI32(rewriter, loc,
97 memRefDescriptor.stride(rewriter, loc, i))
98 : LLVM::ConstantOp::create(rewriter, loc, i32, stride);
99 increment = LLVM::MulOp::create(rewriter, loc, increment, strideValue);
100 }
101 index = index ? LLVM::AddOp::create(rewriter, loc, index, increment)
102 : increment;
103 }
104 return index ? index : createI32Constant(rewriter, loc, 0);
105}
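// Example (illustrative): for a memref with static strides [8, 1] and indices
// (%i, %j), the emitted i32 index is %i * 8 + %j; stride-1 dimensions skip the
// multiply, dynamic strides are read from the descriptor, and a rank-0 access
// yields the constant 0.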
106
107/// Compute the contents of the `num_records` field for a given memref
108/// descriptor - that is, the byte offset one element past the greatest
109/// possible valid index into the memref.
110static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc,
111 MemRefType memrefType,
112 MemRefDescriptor &memrefDescriptor,
113 ArrayRef<int64_t> strides,
114 int64_t elementByteWidth) {
115 if (memrefType.hasStaticShape() &&
116 !llvm::any_of(strides, ShapedType::isDynamic)) {
117 int64_t size = memrefType.getRank() == 0 ? 1 : 0;
118 ArrayRef<int64_t> shape = memrefType.getShape();
119 for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
120 size = std::max(shape[i] * strides[i], size);
121 size = size * elementByteWidth;
122 return createI64Constant(rewriter, loc, size);
123 }
124 Value maxIndex;
125 for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
126 Value size = memrefDescriptor.size(rewriter, loc, i);
127 Value stride = memrefDescriptor.stride(rewriter, loc, i);
128 Value maxThisDim = LLVM::MulOp::create(rewriter, loc, size, stride);
129 maxIndex = maxIndex
130 ? LLVM::UMaxOp::create(rewriter, loc, maxIndex, maxThisDim)
131 : maxThisDim;
132 }
133 Value maxIndexI64 = convertUnsignedToI64(rewriter, loc, maxIndex);
134 Value byteWidthConst = createI64Constant(rewriter, loc, elementByteWidth);
135 return LLVM::MulOp::create(rewriter, loc, maxIndexI64, byteWidthConst);
136}
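// Example (illustrative): for memref<4x8xf32> with the identity layout the
// strides are [8, 1], so the static path computes
// max(4 * 8, 8 * 1) * 4 bytes = 128, i.e. num_records = 128 bytes.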
137
138static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc,
139 Value basePointer, Value numRecords,
140 bool boundsCheck, amdgpu::Chipset chipset,
141 Value cacheSwizzleStride = nullptr,
142 unsigned addressSpace = 8) {
143 // The stride value is generally 0. However, on MI-300 and onward, you can
144 // enable a cache swizzling mode by setting bit 14 of the stride field
145 // and setting that stride to a cache stride.
146 Type i16 = rewriter.getI16Type();
147 Value stride;
148 if (chipset.majorVersion == 9 && chipset >= kGfx942 && cacheSwizzleStride) {
149 Value cacheStrideZext =
150 LLVM::ZExtOp::create(rewriter, loc, i16, cacheSwizzleStride);
151 Value swizzleBit = LLVM::ConstantOp::create(
152 rewriter, loc, i16, rewriter.getI16IntegerAttr(1 << 14));
153 stride = LLVM::OrOp::create(rewriter, loc, cacheStrideZext, swizzleBit,
154 /*isDisjoint=*/true);
155 } else {
156 stride = LLVM::ConstantOp::create(rewriter, loc, i16,
157 rewriter.getI16IntegerAttr(0));
158 }
159 // Get the number of elements.
160 // Flag word:
161 // bits 0-11: dst sel, ignored by these intrinsics
162 // bits 12-14: num format (ignored, must be nonzero, 7=float)
163 // bits 15-18: data format (ignored, must be nonzero, 4=32bit)
164 // bit 19: In nested heap (0 here)
165 // bit 20: Behavior on unmap (0 means "return 0 / ignore")
166 // bits 21-22: Index stride for swizzles (N/A)
167 // bit 23: Add thread ID (0)
168 // bit 24: Reserved to 1 (RDNA) or 0 (CDNA)
169 // bits 25-26: Reserved (0)
170 // bit 27: Buffer is non-volatile (CDNA only)
171 // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 =
172 // none, 3 = either swizzles or testing against offset field) RDNA only
173 // bits 30-31: Type (must be 0)
174 uint32_t flags = (7 << 12) | (4 << 15);
175 if (chipset.majorVersion >= 10) {
176 flags |= (1 << 24);
177 uint32_t oob = boundsCheck ? 3 : 2;
178 flags |= (oob << 28);
179 }
180 Value flagsConst = createI32Constant(rewriter, loc, flags);
181 Type rsrcType =
182 LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
183 Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
184 loc, rsrcType, basePointer, stride, numRecords, flagsConst);
185 return resource;
186}
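// Example (illustrative): on a CDNA chip such as gfx942 the flag word is
// (7 << 12) | (4 << 15) = 0x27000. On an RDNA chip (majorVersion >= 10) with
// bounds checking enabled, bit 24 and the out-of-bounds field are added:
// 0x27000 | (1 << 24) | (3u << 28) = 0x31027000.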
187
188namespace {
189struct FatRawBufferCastLowering
190 : public ConvertOpToLLVMPattern<FatRawBufferCastOp> {
191 FatRawBufferCastLowering(const LLVMTypeConverter &converter, Chipset chipset)
192 : ConvertOpToLLVMPattern<FatRawBufferCastOp>(converter),
193 chipset(chipset) {}
194
195 Chipset chipset;
196
197 LogicalResult
198 matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor,
199 ConversionPatternRewriter &rewriter) const override {
200 Location loc = op.getLoc();
201 Value memRef = adaptor.getSource();
202 Value unconvertedMemref = op.getSource();
203 MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());
204 MemRefDescriptor descriptor(memRef);
205
206 DataLayout dataLayout = DataLayout::closest(op);
207 int64_t elementByteWidth =
208 dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;
209
210 int64_t unusedOffset = 0;
211 SmallVector<int64_t, 5> strideVals;
212 if (failed(memrefType.getStridesAndOffset(strideVals, unusedOffset)))
213 return op.emitOpError("Can't lower non-stride-offset memrefs");
214
215 Value numRecords = adaptor.getValidBytes();
216 if (!numRecords)
217 numRecords = getNumRecords(rewriter, loc, memrefType, descriptor,
218 strideVals, elementByteWidth);
219
220 Value basePointer =
221 adaptor.getResetOffset()
222 ? descriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
223 memrefType)
224 : descriptor.alignedPtr(rewriter, loc);
225
226 Value offset = adaptor.getResetOffset()
227 ? LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
228 rewriter.getIndexAttr(0))
229 : descriptor.offset(rewriter, loc);
230
231 bool hasSizes = memrefType.getRank() > 0;
232 // No need to unpack() and pack() all the individual sizes and strides,
233 // so we'll just extract the arrays.
234 Value sizes = hasSizes
235 ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
236 kSizePosInMemRefDescriptor)
237 : Value{};
238 Value strides =
239 hasSizes ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
240 kStridePosInMemRefDescriptor)
241 : Value{};
242
243 Value fatPtr = makeBufferRsrc(
244 rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(),
245 chipset, adaptor.getCacheSwizzleStride(), /*addressSpace=*/7);
246
247 Value result = MemRefDescriptor::poison(
248 rewriter, loc,
249 getTypeConverter()->convertType(op.getResult().getType()));
250 SmallVector<int64_t> pos{kAllocatedPtrPosInMemRefDescriptor};
251 result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr, pos);
252 result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr,
253 kAlignedPtrPosInMemRefDescriptor);
254 result = LLVM::InsertValueOp::create(rewriter, loc, result, offset,
255 kOffsetPosInMemRefDescriptor);
256 if (hasSizes) {
257 result = LLVM::InsertValueOp::create(rewriter, loc, result, sizes,
258 kSizePosInMemRefDescriptor);
259 result = LLVM::InsertValueOp::create(rewriter, loc, result, strides,
260 kStridePosInMemRefDescriptor);
261 }
262 rewriter.replaceOp(op, result);
263 return success();
264 }
265};
266
267/// Define lowering patterns for raw buffer ops
268template <typename GpuOp, typename Intrinsic>
269struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
270 RawBufferOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
271 : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}
272
273 Chipset chipset;
274 static constexpr uint32_t maxVectorOpWidth = 128;
275
276 LogicalResult
277 matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor,
278 ConversionPatternRewriter &rewriter) const override {
279 Location loc = gpuOp.getLoc();
280 Value memref = adaptor.getMemref();
281 Value unconvertedMemref = gpuOp.getMemref();
282 MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());
283
284 if (chipset.majorVersion < 9)
285 return gpuOp.emitOpError("raw buffer ops require GCN or higher");
286
287 Value storeData = adaptor.getODSOperands(0)[0];
288 if (storeData == memref) // no write component to this op
289 storeData = Value();
290 Type wantedDataType;
291 if (storeData)
292 wantedDataType = storeData.getType();
293 else
294 wantedDataType = gpuOp.getODSResults(0)[0].getType();
295
296 Value atomicCmpData = Value();
297 // Operand index 1 of a load is the indices; trying to read them can crash.
298 if (storeData) {
299 Value maybeCmpData = adaptor.getODSOperands(1)[0];
300 if (maybeCmpData != memref)
301 atomicCmpData = maybeCmpData;
302 }
303
304 Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
305
306 Type i32 = rewriter.getI32Type();
307
308 // Get the type size in bytes.
309 DataLayout dataLayout = DataLayout::closest(gpuOp);
310 int64_t elementByteWidth =
311 dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;
312 Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);
313
314 // If we want to load a vector<NxT> with total size <= 32
315 // bits, use a scalar load and bitcast it. Similarly, if bitsize(T) < 32
316 // and the total load size is >= 32, use a vector load of
317 // (N * bitsize(T) / 32) x i32 and bitcast. Also, the CAS intrinsic
318 // requires integer operands, so bitcast any floats to integers.
319 Type llvmBufferValType = llvmWantedDataType;
320 if (atomicCmpData) {
321 if (auto floatType = dyn_cast<FloatType>(wantedDataType))
322 llvmBufferValType = this->getTypeConverter()->convertType(
323 rewriter.getIntegerType(floatType.getWidth()));
324 }
325 if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
326 uint32_t vecLen = dataVector.getNumElements();
327 uint32_t elemBits =
328 dataLayout.getTypeSizeInBits(dataVector.getElementType());
329 uint32_t totalBits = elemBits * vecLen;
330 bool usePackedFp16 =
331 isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
332 if (totalBits > maxVectorOpWidth)
333 return gpuOp.emitOpError(
334 "Total width of loads or stores must be no more than " +
335 Twine(maxVectorOpWidth) + " bits, but we call for " +
336 Twine(totalBits) +
337 " bits. This should've been caught in validation");
338 if (!usePackedFp16 && elemBits < 32) {
339 if (totalBits > 32) {
340 if (totalBits % 32 != 0)
341 return gpuOp.emitOpError("Load or store of more than 32-bits that "
342 "doesn't fit into words. Can't happen\n");
343 llvmBufferValType = this->typeConverter->convertType(
344 VectorType::get(totalBits / 32, i32));
345 } else {
346 llvmBufferValType = this->typeConverter->convertType(
347 rewriter.getIntegerType(totalBits));
348 }
349 }
350 }
351 if (auto vecType = dyn_cast<VectorType>(llvmBufferValType)) {
352 // Buffer intrinsics don't support 1-element vectors; cast them to
353 // scalars.
354 if (vecType.getNumElements() == 1)
355 llvmBufferValType = vecType.getElementType();
356 }
357
358 SmallVector<Value, 6> args;
359 if (storeData) {
360 if (llvmBufferValType != llvmWantedDataType) {
361 Value castForStore = LLVM::BitcastOp::create(
362 rewriter, loc, llvmBufferValType, storeData);
363 args.push_back(castForStore);
364 } else {
365 args.push_back(storeData);
366 }
367 }
368
369 if (atomicCmpData) {
370 if (llvmBufferValType != llvmWantedDataType) {
371 Value castForCmp = LLVM::BitcastOp::create(
372 rewriter, loc, llvmBufferValType, atomicCmpData);
373 args.push_back(castForCmp);
374 } else {
375 args.push_back(atomicCmpData);
376 }
377 }
378
379 // Construct buffer descriptor from memref, attributes
380 int64_t offset = 0;
381 SmallVector<int64_t, 5> strides;
382 if (failed(memrefType.getStridesAndOffset(strides, offset)))
383 return gpuOp.emitOpError("Can't lower non-stride-offset memrefs");
384
385 MemRefDescriptor memrefDescriptor(memref);
386
387 Value ptr = memrefDescriptor.bufferPtr(
388 rewriter, loc, *this->getTypeConverter(), memrefType);
389 Value numRecords = getNumRecords(
390 rewriter, loc, memrefType, memrefDescriptor, strides, elementByteWidth);
391 Value resource = makeBufferRsrc(rewriter, loc, ptr, numRecords,
392 adaptor.getBoundsCheck(), chipset);
393 args.push_back(resource);
394
395 // Indexing (voffset)
396 Value voffset = getLinearIndexI32(rewriter, loc, memrefDescriptor,
397 adaptor.getIndices(), strides);
398 if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
399 indexOffset && *indexOffset > 0) {
400 Value extraOffsetConst = createI32Constant(rewriter, loc, *indexOffset);
401 voffset = voffset ? LLVM::AddOp::create(rewriter, loc, voffset,
402 extraOffsetConst)
403 : extraOffsetConst;
404 }
405 voffset = LLVM::MulOp::create(rewriter, loc, voffset, byteWidthConst);
406 args.push_back(voffset);
407
408 // SGPR offset.
409 Value sgprOffset = adaptor.getSgprOffset();
410 if (!sgprOffset)
411 sgprOffset = createI32Constant(rewriter, loc, 0);
412 sgprOffset = LLVM::MulOp::create(rewriter, loc, sgprOffset, byteWidthConst);
413 args.push_back(sgprOffset);
414
415 // bit 0: GLC = 0 (atomics drop value, less coherency)
416 // bits 1-2: SLC, DLC = 0 (similarly)
417 // bit 3: swizzled (0 for raw)
418 args.push_back(createI32Constant(rewriter, loc, 0));
419
420 llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
421 llvmBufferValType);
422 Operation *lowered = Intrinsic::create(rewriter, loc, resultTypes, args,
423 ArrayRef<NamedAttribute>());
424 if (lowered->getNumResults() == 1) {
425 Value replacement = lowered->getResult(0);
426 if (llvmBufferValType != llvmWantedDataType) {
427 replacement = LLVM::BitcastOp::create(rewriter, loc, llvmWantedDataType,
429 }
430 rewriter.replaceOp(gpuOp, replacement);
431 } else {
432 rewriter.eraseOp(gpuOp);
433 }
434 return success();
435 }
436};
437
438// TODO: the AMDGPU backend already has all this bitpacking logic; we should
439// move it to some common place.
440/// Vmcnt, Expcnt and Lgkmcnt are decoded as follows:
441/// Vmcnt = Waitcnt[3:0] (pre-gfx9)
442/// Vmcnt = Waitcnt[15:14,3:0] (gfx9,10)
443/// Vmcnt = Waitcnt[15:10] (gfx11)
444/// Expcnt = Waitcnt[6:4] (pre-gfx11)
445/// Expcnt = Waitcnt[2:0] (gfx11)
446/// Lgkmcnt = Waitcnt[11:8] (pre-gfx10)
447/// Lgkmcnt = Waitcnt[13:8] (gfx10)
448/// Lgkmcnt = Waitcnt[9:4] (gfx11)
449static FailureOr<unsigned> encodeWaitcnt(Chipset chipset, unsigned vmcnt,
450 unsigned expcnt, unsigned lgkmcnt) {
451 if (chipset.majorVersion < 9) {
452 vmcnt = std::min(15u, vmcnt);
453 expcnt = std::min(7u, expcnt);
454 lgkmcnt = std::min(15u, lgkmcnt);
455 return vmcnt | (expcnt << 4) | (lgkmcnt << 8);
456 }
457 if (chipset.majorVersion == 9) {
458 vmcnt = std::min(63u, vmcnt);
459 expcnt = std::min(7u, expcnt);
460 lgkmcnt = std::min(15u, lgkmcnt);
461 unsigned lowBits = vmcnt & 0xF;
462 unsigned highBits = (vmcnt >> 4) << 14;
463 unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
464 return lowBits | highBits | otherCnts;
465 }
466 if (chipset.majorVersion == 10) {
467 vmcnt = std::min(63u, vmcnt);
468 expcnt = std::min(7u, expcnt);
469 lgkmcnt = std::min(63u, lgkmcnt);
470 unsigned lowBits = vmcnt & 0xF;
471 unsigned highBits = (vmcnt >> 4) << 14;
472 unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
473 return lowBits | highBits | otherCnts;
474 }
475 if (chipset.majorVersion == 11) {
476 vmcnt = std::min(63u, vmcnt);
477 expcnt = std::min(7u, expcnt);
478 lgkmcnt = std::min(63u, lgkmcnt);
479 return (vmcnt << 10) | expcnt | (lgkmcnt << 4);
480 }
481 return failure();
482}
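// Worked example (illustrative): on gfx9, vmcnt=18, expcnt=2, lgkmcnt=5 gives
// lowBits = 18 & 0xF = 0x2, highBits = (18 >> 4) << 14 = 0x4000, and
// otherCnts = (2 << 4) | (5 << 8) = 0x520, so the encoded waitcnt is 0x4522.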
483
484struct MemoryCounterWaitOpLowering
485 : public ConvertOpToLLVMPattern<MemoryCounterWaitOp> {
486 MemoryCounterWaitOpLowering(const LLVMTypeConverter &converter,
487 Chipset chipset)
488 : ConvertOpToLLVMPattern<MemoryCounterWaitOp>(converter),
489 chipset(chipset) {}
490
491 Chipset chipset;
492
493 LogicalResult
494 matchAndRewrite(MemoryCounterWaitOp op, OpAdaptor adaptor,
495 ConversionPatternRewriter &rewriter) const override {
496 if (chipset.majorVersion >= 12) {
497 Location loc = op.getLoc();
498 if (std::optional<int> ds = adaptor.getDs())
499 ROCDL::WaitDscntOp::create(rewriter, loc, *ds);
501 if (std::optional<int> load = adaptor.getLoad())
502 ROCDL::WaitLoadcntOp::create(rewriter, loc, *load);
503
504 if (std::optional<int> store = adaptor.getStore())
505 ROCDL::WaitStorecntOp::create(rewriter, loc, *store);
506
507 if (std::optional<int> exp = adaptor.getExp())
508 ROCDL::WaitExpcntOp::create(rewriter, loc, *exp);
509
510 if (std::optional<int> tensor = adaptor.getTensor())
511 ROCDL::WaitTensorcntOp::create(rewriter, loc, *tensor);
512
513 rewriter.eraseOp(op);
514 return success();
515 }
516
517 if (adaptor.getTensor())
518 return op.emitOpError("unsupported chipset");
519
520 auto getVal = [](Attribute attr) -> unsigned {
521 if (attr)
522 return cast<IntegerAttr>(attr).getInt();
523
524 // This value will be clamped to the maximum value for the chipset.
525 return 1024;
526 };
527 unsigned ds = getVal(adaptor.getDsAttr());
528 unsigned exp = getVal(adaptor.getExpAttr());
529
530 unsigned vmcnt = 1024;
531 Attribute load = adaptor.getLoadAttr();
532 Attribute store = adaptor.getStoreAttr();
533 if (load && store) {
534 vmcnt = getVal(load) + getVal(store);
535 } else if (load) {
536 vmcnt = getVal(load);
537 } else if (store) {
538 vmcnt = getVal(store);
539 }
540
541 FailureOr<unsigned> waitcnt = encodeWaitcnt(chipset, vmcnt, exp, ds);
542 if (failed(waitcnt))
543 return op.emitOpError("unsupported chipset");
544
545 rewriter.replaceOpWithNewOp<ROCDL::SWaitcntOp>(op, *waitcnt);
546 return success();
547 }
548};
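// Example (illustrative): on gfx9, a memory_counter_wait with only load(2) set
// leaves ds and exp at the 1024 default; encodeWaitcnt clamps them to
// lgkmcnt=15 and expcnt=7, giving (15 << 8) | (7 << 4) | 2 = 0xF72 as the
// immediate of the emitted ROCDL::SWaitcntOp. On gfx12+ each set counter
// instead lowers to its own wait op (e.g. ROCDL::WaitLoadcntOp).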
549
550struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
551 LDSBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
552 : ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {}
553
554 Chipset chipset;
555
556 LogicalResult
557 matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
558 ConversionPatternRewriter &rewriter) const override {
559 Location loc = op.getLoc();
560 // This ensures that waits on global memory aren't introduced on
561 // chips that don't have the BackOffBarrier feature enabled in LLVM.
562 bool requiresInlineAsm = chipset < kGfx90a;
563
564 Attribute mmra =
565 rewriter.getAttr<LLVM::MMRATagAttr>("amdgpu-synchronize-as", "local");
566 // Note: while there *is* a workgroup-one-as scope, this, when combined with
567 // the MMRA, will lead to the fence having no effect. This is because the
568 // codepaths for an atomic load or store will observe that a
569 // one-address-space atomic to LDS requires no synchronization because
570 // operations on LDS are totally ordered with respect to each other, and so
571 // will not emit the correct waitcnt operations that these fences are
572 // intended to produce. Therefore, we use a broader type of fence and rely
573 // on the MMRA to relax it to the semantics we want.
574 StringRef scope = "workgroup";
575
576 auto relFence = LLVM::FenceOp::create(rewriter, loc,
577 LLVM::AtomicOrdering::release, scope);
578 relFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra);
579 if (requiresInlineAsm) {
580 auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
581 LLVM::AsmDialect::AD_ATT);
582 const char *asmStr = ";;;WARNING: BREAKS DEBUG WATCHES\ns_barrier";
583 const char *constraints = "";
584 LLVM::InlineAsmOp::create(
585 rewriter, loc,
586 /*resultTypes=*/TypeRange(), /*operands=*/ValueRange(),
587 /*asm_string=*/asmStr, constraints, /*has_side_effects=*/true,
588 /*is_align_stack=*/false, LLVM::TailCallKind::None,
589 /*asm_dialect=*/asmDialectAttr,
590 /*operand_attrs=*/ArrayAttr());
591 } else if (chipset.majorVersion < 12) {
592 ROCDL::SBarrierOp::create(rewriter, loc);
593 } else {
594 ROCDL::BarrierSignalOp::create(rewriter, loc, -1);
595 ROCDL::BarrierWaitOp::create(rewriter, loc, -1);
596 }
597
598 auto acqFence = LLVM::FenceOp::create(rewriter, loc,
599 LLVM::AtomicOrdering::acquire, scope);
600 acqFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra);
601 rewriter.replaceOp(op, acqFence);
602 return success();
603 }
604};
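// Example (illustrative): on gfx942 the barrier lowers to a release fence and
// an acquire fence at "workgroup" scope, both tagged with the
// "amdgpu-synchronize-as":"local" MMRA, surrounding a ROCDL::SBarrierOp; gfx12+
// uses the split barrier (ROCDL::BarrierSignalOp / ROCDL::BarrierWaitOp with
// -1), and chips older than gfx90a fall back to the inline-asm s_barrier.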
605
606struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
607 SchedBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
608 : ConvertOpToLLVMPattern<SchedBarrierOp>(converter), chipset(chipset) {}
609
610 Chipset chipset;
611
612 LogicalResult
613 matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor,
614 ConversionPatternRewriter &rewriter) const override {
615 rewriter.replaceOpWithNewOp<ROCDL::SchedBarrier>(op,
616 (uint32_t)op.getOpts());
617 return success();
618 }
619};
620
621} // namespace
622
623/// Pack small float vector operands (fp4/fp6/fp8/bf16) into the format
624/// expected by scaled matrix multiply intrinsics (MFMA/WMMA).
625///
626/// Specifically:
627/// 1. If the element type is bfloat16, bitcast it to i16 unless the rocdl
628/// intrinsic allows bf16. Newer MFMAs support bf16 operands; check the
629/// IntrinsicsAMDGPU.td file for reference.
630/// 2. If we instead have a quantity wider than 64 bits, use a <N / 4 x i32>
631/// vector instead, which is what the f8f6f4 intrinsics use.
632/// 3. If `input` is a vector of N <= 8 bytes, bitcast it to a (N * 8)-bit
633/// integer.
634///
635/// Note that the type of `input` has already been LLVM type converted:
636/// therefore 8-bit and smaller floats are represented as their corresponding
637/// `iN` integers.
638static Value packSmallFloatVectorOperand(ConversionPatternRewriter &rewriter,
639 Location loc, Value input,
640 bool allowBf16 = true) {
641 Type inputType = input.getType();
642 if (auto vectorType = dyn_cast<VectorType>(inputType)) {
643 if (vectorType.getElementType().isBF16() && !allowBf16)
644 return LLVM::BitcastOp::create(
645 rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
646 if (vectorType.getElementType().isInteger(8) &&
647 vectorType.getNumElements() <= 8)
648 return LLVM::BitcastOp::create(
649 rewriter, loc,
650 rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
651 if (isa<IntegerType>(vectorType.getElementType()) &&
652 vectorType.getElementTypeBitWidth() <= 8) {
653 int64_t numWords = llvm::divideCeil(
654 vectorType.getNumElements() * vectorType.getElementTypeBitWidth(),
655 32);
656 return LLVM::BitcastOp::create(
657 rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()),
658 input);
659 }
660 }
661 return input;
662}
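// Example (illustrative): after LLVM type conversion a vector<8xf8E4M3FN>
// operand arrives here as vector<8xi8> (8 bytes) and is bitcast to i64, while
// a vector<32xi8> (256 bits) becomes vector<8xi32>; a vector<4xbf16> with
// allowBf16 = false is bitcast to vector<4xi16>.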
663
664/// Converts sparse MFMA (smfmac) operands to the expected ROCDL types.
665static Value convertSparseMFMAVectorOperand(ConversionPatternRewriter &rewriter,
666 Location loc, Value input,
667 bool allowBf16 = true) {
668 Type inputType = input.getType();
669 auto vectorType = cast<VectorType>(inputType);
670 // bf16 -> i16 when not allowed (pre-gfx950).
671 if (vectorType.getElementType().isBF16() && !allowBf16)
672 return LLVM::BitcastOp::create(
673 rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
674 // i8/fp8 vectors -> vector<Nxi32>.
675 if (isa<IntegerType>(vectorType.getElementType()) &&
676 vectorType.getElementTypeBitWidth() <= 8) {
677 int64_t numWords = llvm::divideCeil(
678 vectorType.getNumElements() * vectorType.getElementTypeBitWidth(), 32);
679 return LLVM::BitcastOp::create(
680 rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()), input);
681 }
682 return input;
683}
684
685/// Converts the scaled MFMA/WMMA operands, `scalesA` and `scalesB`, from MLIR
686/// AMDGPU dialect convention to ROCDL and LLVM AMDGPU intrinsics convention.
687///
688/// Specifically:
689/// 1. If `input` is an i8 value, zero-extend it to i32
690/// 2. If `input` is a vector of length 4 or 8 and type i8, bitcast it to i32
691/// or i64, respectively
691///
692/// Note that the type of `input` has already been LLVM type converted:
693/// therefore 8-bit and smaller floats are represented as their corresponding
694/// `iN` integers.
695static Value castScaleOperand(ConversionPatternRewriter &rewriter, Location loc,
696 Value input) {
697 return TypeSwitch<Type, Value>(input.getType())
698 .Case([&](IntegerType) {
699 // Handle scalar i8: zero extend to i32.
700 return LLVM::ZExtOp::create(rewriter, loc, rewriter.getI32Type(),
701 input);
702 })
703 .Case([&](VectorType vectorType) {
704 // Handle vector<4xi8> -> i32 or vector<8xi8> -> i64.
705 int64_t numElements = vectorType.getNumElements();
706 assert((numElements == 4 || numElements == 8) &&
707 "scale operand must be a vector of length 4 or 8");
708 IntegerType outputType =
709 (numElements == 4) ? rewriter.getI32Type() : rewriter.getI64Type();
710 return LLVM::BitcastOp::create(rewriter, loc, outputType, input);
711 })
712 .DefaultUnreachable("unexpected input type for scale operand");
713}
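// Example (illustrative): a scalar i8 scale is zero-extended to i32, a
// vector<4xi8> is bitcast to i32, and a vector<8xi8> is bitcast to i64.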
714
715/// Maps f8 scale element types to WMMA scale format codes.
716static std::optional<uint32_t> getWmmaScaleFormat(Type elemType) {
717 return TypeSwitch<Type, std::optional<uint32_t>>(elemType)
718 .Case([](Float8E8M0FNUType) { return 0; })
719 .Case([](Float8E4M3FNType) { return 2; })
720 .Default(std::nullopt);
721}
722
723/// Determines the ROCDL intrinsic name for scaled WMMA based on dimensions
724/// and scale block size (16 or 32).
725static std::optional<StringRef>
727 if (m == 16 && n == 16 && k == 128)
728 return isScale16
729 ? ROCDL::wmma_scale16_f32_16x16x128_f8f6f4::getOperationName()
730 : ROCDL::wmma_scale_f32_16x16x128_f8f6f4::getOperationName();
731
732 if (m == 32 && n == 16 && k == 128)
733 return isScale16 ? ROCDL::wmma_scale16_f32_32x16x128_f4::getOperationName()
734 : ROCDL::wmma_scale_f32_32x16x128_f4::getOperationName();
735
736 return std::nullopt;
737}
738
739/// Push an input operand. If it is a float type, nothing to do. If it is
740/// an integer type, then we also need to push its signedness (1 for signed, 0
741/// for unsigned) and we need to pack the input 16xi8 vector into a 4xi32
742/// vector (or the 8xi8 vector into a 2xi32 one for gfx12+).
743/// We also need to convert bfloat inputs to i16 to account for the bfloat
744/// intrinsics having been defined before the AMD backend supported bfloat. We
745/// similarly need to pack 8-bit float types into integers as if they were i8
746/// (which they are for the backend's purposes).
747static void wmmaPushInputOperand(
748 ConversionPatternRewriter &rewriter, Location loc,
749 const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput,
750 Value mlirInput, SmallVectorImpl<Value> &operands,
751 SmallVectorImpl<NamedAttribute> &attrs, StringRef attrName) {
752 Type inputType = llvmInput.getType();
753 auto vectorType = dyn_cast<VectorType>(inputType);
754 if (!vectorType) {
755 operands.push_back(llvmInput);
756 return;
757 }
758 Type elemType = vectorType.getElementType();
759 if (elemType.getIntOrFloatBitWidth() > 8) {
760 operands.push_back(llvmInput);
761 return;
762 }
763
764 // We need to check the type of the input before conversion to properly test
765 // for int8. This is because, in LLVM, fp8 type is converted to int8, so the
766 // fp8/int8 information is lost during the conversion process.
767 auto mlirInputType = cast<VectorType>(mlirInput.getType());
768 bool isInputInteger = mlirInputType.getElementType().isInteger();
769 if (isInputInteger) {
770 // if element type is 8-bit signed or unsigned, ignore the isUnsigned flag
771 bool localIsUnsigned = isUnsigned;
772 if (elemType.isUnsignedInteger()) {
773 localIsUnsigned = true;
774 } else if (elemType.isSignedInteger()) {
775 localIsUnsigned = false;
776 }
777 attrs.push_back(
778 NamedAttribute(attrName, rewriter.getBoolAttr(!localIsUnsigned)));
779 }
780
781 int64_t numBits =
782 vectorType.getNumElements() * elemType.getIntOrFloatBitWidth();
783 Type i32 = rewriter.getI32Type();
784 Type intrinsicInType = numBits <= 32
785 ? (Type)rewriter.getIntegerType(numBits)
786 : (Type)VectorType::get(numBits / 32, i32);
787 auto llvmIntrinsicInType = typeConverter->convertType(intrinsicInType);
788 Value castInput = rewriter.createOrFold<LLVM::BitcastOp>(
789 loc, llvmIntrinsicInType, llvmInput);
790 // The wave64-mode 16x16x16 intrinsics that take 4-bit integers only need
791 // (256 / 64) * 4 = 16 bits of input (on gfx12+) but take i32 arguments.
792 // Add in the zeros here.
793 if (numBits < 32)
794 castInput = LLVM::ZExtOp::create(rewriter, loc, i32, castInput);
795 operands.push_back(castInput);
796}
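// Example (illustrative): a vector<16xi8> source (128 bits) is bitcast to
// vector<4xi32> and, because the MLIR element type is an integer, a boolean
// signedness attribute (true unless treated as unsigned) is recorded under
// attrName; a vector<4xi4> source (16 bits) is bitcast to i16 and then
// zero-extended to the i32 the intrinsic expects.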
797
798/// Push the output operand. For many cases this is only pushing the output in
799/// the operand list. But when we have f16 -> f16 or bf16 -> bf16 intrinsics,
800/// since the same number of VGPRs is used, we need to decide whether to store
801/// the result in the upper 16 bits of the VGPRs or in the lower part. To store
802/// the result in the lower 16 bits, set subwordOffset to 1; otherwise the result
803/// will be stored in the upper part. The subwordOffset must not be set for gfx12,
804/// as the instructions have been changed to return fewer registers instead.
805static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
806 Location loc,
807 const TypeConverter *typeConverter,
808 Value output, int32_t subwordOffset,
809 bool clamp, SmallVectorImpl<Value> &operands,
810 SmallVectorImpl<NamedAttribute> &attrs) {
811 Type inputType = output.getType();
812 auto vectorType = dyn_cast<VectorType>(inputType);
813 Type elemType = vectorType.getElementType();
814 operands.push_back(output);
815 if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) {
816 attrs.push_back(
817 NamedAttribute("opsel", rewriter.getBoolAttr(subwordOffset)));
818 } else if (elemType.isInteger(32)) {
819 attrs.push_back(NamedAttribute("clamp", rewriter.getBoolAttr(clamp)));
820 }
821}
822
823/// Return true if `type` is the E5M2 variant of an 8-bit float that is
824/// supported by the `_bf8` instructions on the given `chipset`.
825static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type) {
826 return (chipset == kGfx942 && isa<Float8E5M2FNUZType>(type)) ||
827 (hasOcpFp8(chipset) && isa<Float8E5M2Type>(type));
828}
829
830/// Return true if `type` is the E4M3FN variant of an 8-bit float that is
831/// supported by the `_fp8` instructions on the given `chipset`.
832static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type) {
833 return (chipset == kGfx942 && isa<Float8E4M3FNUZType>(type)) ||
834 (hasOcpFp8(chipset) && isa<Float8E4M3FNType>(type));
835}
836
837/// Return the `rocdl` intrinsic corresponding to a MFMA operation `mfma`
838/// if one exists. This includes checking to ensure the intrinsic is supported
839/// on the architecture you are compiling for.
840static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
841 Chipset chipset) {
842 uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
843 b = mfma.getBlocks();
844 Type sourceElem = getElementTypeOrSelf(mfma.getSourceA().getType());
845 Type destElem = getElementTypeOrSelf(mfma.getDestC().getType());
846
847 if (sourceElem.isF32() && destElem.isF32()) {
848 if (mfma.getReducePrecision() && chipset >= kGfx942) {
849 if (m == 32 && n == 32 && k == 4 && b == 1)
850 return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
851 if (m == 16 && n == 16 && k == 8 && b == 1)
852 return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
853 }
854 if (m == 32 && n == 32 && k == 1 && b == 2)
855 return ROCDL::mfma_f32_32x32x1f32::getOperationName();
856 if (m == 16 && n == 16 && k == 1 && b == 4)
857 return ROCDL::mfma_f32_16x16x1f32::getOperationName();
858 if (m == 4 && n == 4 && k == 1 && b == 16)
859 return ROCDL::mfma_f32_4x4x1f32::getOperationName();
860 if (m == 32 && n == 32 && k == 2 && b == 1)
861 return ROCDL::mfma_f32_32x32x2f32::getOperationName();
862 if (m == 16 && n == 16 && k == 4 && b == 1)
863 return ROCDL::mfma_f32_16x16x4f32::getOperationName();
864 }
865
866 if (sourceElem.isF16() && destElem.isF32()) {
867 if (chipset >= kGfx950) {
868 if (m == 32 && n == 32 && k == 16 && b == 1)
869 return ROCDL::mfma_f32_32x32x16_f16::getOperationName();
870 if (m == 16 && n == 16 && k == 32 && b == 1)
871 return ROCDL::mfma_f32_16x16x32_f16::getOperationName();
872 }
873 if (m == 32 && n == 32 && k == 4 && b == 2)
874 return ROCDL::mfma_f32_32x32x4f16::getOperationName();
875 if (m == 16 && n == 16 && k == 4 && b == 4)
876 return ROCDL::mfma_f32_16x16x4f16::getOperationName();
877 if (m == 4 && n == 4 && k == 4 && b == 16)
878 return ROCDL::mfma_f32_4x4x4f16::getOperationName();
879 if (m == 32 && n == 32 && k == 8 && b == 1)
880 return ROCDL::mfma_f32_32x32x8f16::getOperationName();
881 if (m == 16 && n == 16 && k == 16 && b == 1)
882 return ROCDL::mfma_f32_16x16x16f16::getOperationName();
883 }
884
885 if (sourceElem.isBF16() && destElem.isF32()) {
886 if (chipset >= kGfx950) {
887 if (m == 32 && n == 32 && k == 16 && b == 1)
888 return ROCDL::mfma_f32_32x32x16_bf16::getOperationName();
889 if (m == 16 && n == 16 && k == 32 && b == 1)
890 return ROCDL::mfma_f32_16x16x32_bf16::getOperationName();
891 }
892 if (chipset >= kGfx90a) {
893 if (m == 32 && n == 32 && k == 4 && b == 2)
894 return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
895 if (m == 16 && n == 16 && k == 4 && b == 4)
896 return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
897 if (m == 4 && n == 4 && k == 4 && b == 16)
898 return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
899 if (m == 32 && n == 32 && k == 8 && b == 1)
900 return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
901 if (m == 16 && n == 16 && k == 16 && b == 1)
902 return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
903 }
904 if (m == 32 && n == 32 && k == 2 && b == 2)
905 return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
906 if (m == 16 && n == 16 && k == 2 && b == 4)
907 return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
908 if (m == 4 && n == 4 && k == 2 && b == 16)
909 return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
910 if (m == 32 && n == 32 && k == 4 && b == 1)
911 return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
912 if (m == 16 && n == 16 && k == 8 && b == 1)
913 return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
914 }
915
916 if (sourceElem.isInteger(8) && destElem.isInteger(32)) {
917 if (chipset >= kGfx950) {
918 if (m == 32 && n == 32 && k == 32 && b == 1)
919 return ROCDL::mfma_i32_32x32x32_i8::getOperationName();
920 if (m == 16 && n == 16 && k == 64 && b == 1)
921 return ROCDL::mfma_i32_16x16x64_i8::getOperationName();
922 }
923 if (m == 32 && n == 32 && k == 4 && b == 2)
924 return ROCDL::mfma_i32_32x32x4i8::getOperationName();
925 if (m == 16 && n == 16 && k == 4 && b == 4)
926 return ROCDL::mfma_i32_16x16x4i8::getOperationName();
927 if (m == 4 && n == 4 && k == 4 && b == 16)
928 return ROCDL::mfma_i32_4x4x4i8::getOperationName();
929 if (m == 32 && n == 32 && k == 8 && b == 1)
930 return ROCDL::mfma_i32_32x32x8i8::getOperationName();
931 if (m == 16 && n == 16 && k == 16 && b == 1)
932 return ROCDL::mfma_i32_16x16x16i8::getOperationName();
933 if (m == 32 && n == 32 && k == 16 && b == 1 && chipset >= kGfx942)
934 return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
935 if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >= kGfx942)
936 return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
937 }
938
939 if (sourceElem.isF64() && destElem.isF64() && chipset >= kGfx90a) {
940 if (m == 16 && n == 16 && k == 4 && b == 1)
941 return ROCDL::mfma_f64_16x16x4f64::getOperationName();
942 if (m == 4 && n == 4 && k == 4 && b == 4)
943 return ROCDL::mfma_f64_4x4x4f64::getOperationName();
944 }
945
946 if (destElem.isF32() && typeIsExpectedBf8ForChipset(chipset, sourceElem)) {
947 // Known to be correct because there are no scalar f8 instructions and
948 // because a length mismatch will have been caught by the verifier.
949 Type sourceBElem =
950 cast<VectorType>(mfma.getSourceB().getType()).getElementType();
951 if (m == 16 && n == 16 && k == 32 && b == 1) {
952 if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
953 return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
954 if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
955 return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
956 }
957 if (m == 32 && n == 32 && k == 16 && b == 1) {
958 if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
959 return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
960 if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
961 return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
962 }
963 }
964
965 if (destElem.isF32() && typeIsExpectedFp8ForChipset(chipset, sourceElem)) {
966 Type sourceBElem =
967 cast<VectorType>(mfma.getSourceB().getType()).getElementType();
968 if (m == 16 && n == 16 && k == 32 && b == 1) {
969 if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
970 return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
971 if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
972 return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
973 }
974 if (m == 32 && n == 32 && k == 16 && b == 1) {
975 if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
976 return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
977 if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
978 return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
979 }
980 }
981
982 return std::nullopt;
983}
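// Example (illustrative): a 16x16x16 (blocks = 1) MFMA with f16 sources and an
// f32 destination maps to ROCDL::mfma_f32_16x16x16f16, while the same shape
// with i8 sources and an i32 destination maps to ROCDL::mfma_i32_16x16x16i8.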
984
985static std::optional<uint32_t> smallFloatTypeToFormatCode(Type mlirElemType) {
986 return TypeSwitch<Type, std::optional<uint32_t>>(mlirElemType)
987 .Case([](Float8E4M3FNType) { return 0u; })
988 .Case([](Float8E5M2Type) { return 1u; })
989 .Case([](Float6E2M3FNType) { return 2u; })
990 .Case([](Float6E3M2FNType) { return 3u; })
991 .Case([](Float4E2M1FNType) { return 4u; })
992 .Default(std::nullopt);
993}
994
995/// If there is a scaled MFMA instruction for the input element types `aType`
996/// and `bType`, output type `destType`, problem size M, N, K, and B (number of
997/// blocks) on the given `chipset`, return a tuple consisting of the
998/// OperationName of the intrinsic and the type codes that need to be passed to
999/// that intrinsic. Note that this is also used to implement some un-scaled
1000/// MFMAs, since the compiler represents the ordinary instruction as a "scaled"
1001/// MFMA with a scale of 0.
1002static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1003mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m,
1004 uint32_t n, uint32_t k, uint32_t b, Chipset chipset) {
1005 aType = getElementTypeOrSelf(aType);
1006 bType = getElementTypeOrSelf(bType);
1007 destType = getElementTypeOrSelf(destType);
1008
1009 if (chipset < kGfx950)
1010 return std::nullopt;
1011 if (!isa<Float32Type>(destType))
1012 return std::nullopt;
1013
1014 std::optional<uint32_t> aTypeCode = smallFloatTypeToFormatCode(aType);
1015 std::optional<uint32_t> bTypeCode = smallFloatTypeToFormatCode(bType);
1016 if (!aTypeCode || !bTypeCode)
1017 return std::nullopt;
1018
1019 if (m == 32 && n == 32 && k == 64 && b == 1)
1020 return std::tuple{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(),
1021 *aTypeCode, *bTypeCode};
1022 if (m == 16 && n == 16 && k == 128 && b == 1)
1023 return std::tuple{
1024 ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), *aTypeCode,
1025 *bTypeCode};
1026
1027 return std::nullopt;
1028}
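// Example (illustrative): on gfx950, aType = f8E4M3FN (code 0) and
// bType = f8E5M2 (code 1) with an f32 destination and m = n = 32, k = 64,
// b = 1 yields {ROCDL::mfma_scale_f32_32x32x64_f8f6f4, 0, 1}.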
1029
1030static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1031mfmaOpToScaledIntrinsic(MFMAOp mfma, Chipset chipset) {
1032 return mfmaOpToScaledIntrinsic(
1033 mfma.getSourceA().getType(), mfma.getSourceB().getType(),
1034 mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(),
1035 mfma.getBlocks(), chipset);
1036}
1037
1038static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1039mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
1040 return mfmaOpToScaledIntrinsic(smfma.getSourceA().getType(),
1041 smfma.getSourceB().getType(),
1042 smfma.getDestC().getType(), smfma.getM(),
1043 smfma.getN(), smfma.getK(), 1u, chipset);
1044}
1045
1046/// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
1047/// for RDNA3/4 architectures.
1048static std::optional<StringRef>
1049wmmaOpToIntrinsicRDNA(Type elemSourceType, Type elemBSourceType,
1050 Type elemDestType, uint32_t k, bool isRDNA3) {
1051 using fp8 = Float8E4M3FNType;
1052 using bf8 = Float8E5M2Type;
1053
1054 // Handle k == 16 for RDNA3/4.
1055 if (k == 16) {
1056 // Common patterns for RDNA3 and RDNA4.
1057 if (elemSourceType.isF16() && elemDestType.isF32())
1058 return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
1059 if (elemSourceType.isBF16() && elemDestType.isF32())
1060 return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
1061 if (elemSourceType.isF16() && elemDestType.isF16())
1062 return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
1063 if (elemSourceType.isBF16() && elemDestType.isBF16())
1064 return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
1065 if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
1066 return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
1067
1068 // RDNA3 specific patterns.
1069 if (isRDNA3) {
1070 if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
1071 return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
1072 return std::nullopt;
1073 }
1074
1075 // RDNA4 specific patterns (fp8/bf8).
1076 if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
1077 elemDestType.isF32())
1078 return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
1079 if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
1080 elemDestType.isF32())
1081 return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
1082 if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
1083 elemDestType.isF32())
1084 return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
1085 if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
1086 elemDestType.isF32())
1087 return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
1088 if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
1089 return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
1090
1091 return std::nullopt;
1092 }
1093
1094 // Handle k == 32 for RDNA4.
1095 if (k == 32 && !isRDNA3) {
1096 if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
1097 return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
1098 }
1099
1100 return std::nullopt;
1101}
1102
1103/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
1104/// for the gfx1250 architecture.
1105static std::optional<StringRef> wmmaOpToIntrinsicGfx1250(Type elemSourceType,
1106 Type elemBSourceType,
1107 Type elemDestType,
1108 uint32_t k) {
1109 using fp8 = Float8E4M3FNType;
1110 using bf8 = Float8E5M2Type;
1111
1112 if (k == 4) {
1113 if (elemSourceType.isF32() && elemDestType.isF32())
1114 return ROCDL::wmma_f32_16x16x4_f32::getOperationName();
1115
1116 return std::nullopt;
1117 }
1118
1119 if (k == 32) {
1120 if (elemSourceType.isF16() && elemDestType.isF32())
1121 return ROCDL::wmma_f32_16x16x32_f16::getOperationName();
1122 if (elemSourceType.isBF16() && elemDestType.isF32())
1123 return ROCDL::wmma_f32_16x16x32_bf16::getOperationName();
1124 if (elemSourceType.isF16() && elemDestType.isF16())
1125 return ROCDL::wmma_f16_16x16x32_f16::getOperationName();
1126 if (elemSourceType.isBF16() && elemDestType.isBF16())
1127 return ROCDL::wmma_bf16_16x16x32_bf16::getOperationName();
1128
1129 return std::nullopt;
1130 }
1131
1132 if (k == 64) {
1133 if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1134 if (elemDestType.isF32())
1135 return ROCDL::wmma_f32_16x16x64_fp8_fp8::getOperationName();
1136 if (elemDestType.isF16())
1137 return ROCDL::wmma_f16_16x16x64_fp8_fp8::getOperationName();
1138 }
1139 if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1140 if (elemDestType.isF32())
1141 return ROCDL::wmma_f32_16x16x64_fp8_bf8::getOperationName();
1142 if (elemDestType.isF16())
1143 return ROCDL::wmma_f16_16x16x64_fp8_bf8::getOperationName();
1144 }
1145 if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1146 if (elemDestType.isF32())
1147 return ROCDL::wmma_f32_16x16x64_bf8_bf8::getOperationName();
1148 if (elemDestType.isF16())
1149 return ROCDL::wmma_f16_16x16x64_bf8_bf8::getOperationName();
1150 }
1151 if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1152 if (elemDestType.isF32())
1153 return ROCDL::wmma_f32_16x16x64_bf8_fp8::getOperationName();
1154 if (elemDestType.isF16())
1155 return ROCDL::wmma_f16_16x16x64_bf8_fp8::getOperationName();
1156 }
1157 if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
1158 return ROCDL::wmma_i32_16x16x64_iu8::getOperationName();
1159
1160 return std::nullopt;
1161 }
1162
1163 if (k == 128) {
1164 if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1165 if (elemDestType.isF32())
1166 return ROCDL::wmma_f32_16x16x128_fp8_fp8::getOperationName();
1167 if (elemDestType.isF16())
1168 return ROCDL::wmma_f16_16x16x128_fp8_fp8::getOperationName();
1169 }
1170 if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1171 if (elemDestType.isF32())
1172 return ROCDL::wmma_f32_16x16x128_fp8_bf8::getOperationName();
1173 if (elemDestType.isF16())
1174 return ROCDL::wmma_f16_16x16x128_fp8_bf8::getOperationName();
1175 }
1176 if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1177 if (elemDestType.isF32())
1178 return ROCDL::wmma_f32_16x16x128_bf8_bf8::getOperationName();
1179 if (elemDestType.isF16())
1180 return ROCDL::wmma_f16_16x16x128_bf8_bf8::getOperationName();
1181 }
1182 if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1183 if (elemDestType.isF32())
1184 return ROCDL::wmma_f32_16x16x128_bf8_fp8::getOperationName();
1185 if (elemDestType.isF16())
1186 return ROCDL::wmma_f16_16x16x128_bf8_fp8::getOperationName();
1187 }
1188
1189 return std::nullopt;
1190 }
1191
1192 return std::nullopt;
1193}
1194
1195/// Returns the `rocdl` intrinsic corresponding to a SparseMFMA (smfmac)
1196/// operation if one exists. This includes checking to ensure the intrinsic is
1197/// supported on the architecture you are compiling for.
1198static std::optional<StringRef> smfmacOpToIntrinsic(SparseMFMAOp op,
1199 Chipset chipset) {
1200 bool isGfx950 = chipset >= kGfx950;
1201 auto isFp8 = [&](Type t) { return typeIsExpectedFp8ForChipset(chipset, t); };
1202 auto isBf8 = [&](Type t) { return typeIsExpectedBf8ForChipset(chipset, t); };
1203
1204 uint32_t m = op.getM(), n = op.getN(), k = op.getK();
1205 Type sourceAElem = getElementTypeOrSelf(op.getSourceA().getType());
1206 Type sourceBElem = getElementTypeOrSelf(op.getSourceB().getType());
1207 Type destElem = getElementTypeOrSelf(op.getDestC().getType());
1208
1209 if (m == 16 && n == 16 && k == 32) {
1210 if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32())
1211 return ROCDL::smfmac_f32_16x16x32_f16::getOperationName();
1212 if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32())
1213 return ROCDL::smfmac_f32_16x16x32_bf16::getOperationName();
1214 }
1215
1216 if (m == 16 && n == 16 && k == 64) {
1217 if (isGfx950) {
1218 if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32())
1219 return ROCDL::smfmac_f32_16x16x64_f16::getOperationName();
1220 if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32())
1221 return ROCDL::smfmac_f32_16x16x64_bf16::getOperationName();
1222 }
1223 if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
1224 destElem.isInteger(32))
1225 return ROCDL::smfmac_i32_16x16x64_i8::getOperationName();
1226 if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
1227 return ROCDL::smfmac_f32_16x16x64_fp8_fp8::getOperationName();
1228 if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
1229 return ROCDL::smfmac_f32_16x16x64_fp8_bf8::getOperationName();
1230 if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
1231 return ROCDL::smfmac_f32_16x16x64_bf8_fp8::getOperationName();
1232 if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
1233 return ROCDL::smfmac_f32_16x16x64_bf8_bf8::getOperationName();
1234 }
1235
1236 if (m == 16 && n == 16 && k == 128 && isGfx950) {
1237 if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
1238 destElem.isInteger(32))
1239 return ROCDL::smfmac_i32_16x16x128_i8::getOperationName();
1240 if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
1241 return ROCDL::smfmac_f32_16x16x128_fp8_fp8::getOperationName();
1242 if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
1243 return ROCDL::smfmac_f32_16x16x128_fp8_bf8::getOperationName();
1244 if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
1245 return ROCDL::smfmac_f32_16x16x128_bf8_fp8::getOperationName();
1246 if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
1247 return ROCDL::smfmac_f32_16x16x128_bf8_bf8::getOperationName();
1248 }
1249
1250 if (m == 32 && n == 32 && k == 16) {
1251 if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32())
1252 return ROCDL::smfmac_f32_32x32x16_f16::getOperationName();
1253 if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32())
1254 return ROCDL::smfmac_f32_32x32x16_bf16::getOperationName();
1255 }
1256
1257 if (m == 32 && n == 32 && k == 32) {
1258 if (isGfx950) {
1259 if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32())
1260 return ROCDL::smfmac_f32_32x32x32_f16::getOperationName();
1261 if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32())
1262 return ROCDL::smfmac_f32_32x32x32_bf16::getOperationName();
1263 }
1264 if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
1265 destElem.isInteger(32))
1266 return ROCDL::smfmac_i32_32x32x32_i8::getOperationName();
1267 if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
1268 return ROCDL::smfmac_f32_32x32x32_fp8_fp8::getOperationName();
1269 if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
1270 return ROCDL::smfmac_f32_32x32x32_fp8_bf8::getOperationName();
1271 if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
1272 return ROCDL::smfmac_f32_32x32x32_bf8_fp8::getOperationName();
1273 if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
1274 return ROCDL::smfmac_f32_32x32x32_bf8_bf8::getOperationName();
1275 }
1276
1277 if (m == 32 && n == 32 && k == 64 && isGfx950) {
1278 if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
1279 destElem.isInteger(32))
1280 return ROCDL::smfmac_i32_32x32x64_i8::getOperationName();
1281 if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
1282 return ROCDL::smfmac_f32_32x32x64_fp8_fp8::getOperationName();
1283 if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
1284 return ROCDL::smfmac_f32_32x32x64_fp8_bf8::getOperationName();
1285 if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
1286 return ROCDL::smfmac_f32_32x32x64_bf8_fp8::getOperationName();
1287 if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
1288 return ROCDL::smfmac_f32_32x32x64_bf8_bf8::getOperationName();
1289 }
1290
1291 return std::nullopt;
1292}
1293
1294/// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
1295/// if one exists. This includes checking to ensure the intrinsic is supported
1296/// on the architecture you are compiling for.
1297static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
1298 Chipset chipset) {
1299 auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
1300 auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
1301 auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
1302 Type elemSourceType = sourceVectorType.getElementType();
1303 Type elemBSourceType = sourceBVectorType.getElementType();
1304 Type elemDestType = destVectorType.getElementType();
1305
1306 const uint32_t k = wmma.getK();
1307 const bool isRDNA3 = chipset.majorVersion == 11;
1308 const bool isRDNA4 = chipset.majorVersion == 12 && chipset.minorVersion == 0;
1309
1310 // Handle RDNA3 and RDNA4.
1311 if (isRDNA3 || isRDNA4)
1312 return wmmaOpToIntrinsicRDNA(elemSourceType, elemBSourceType, elemDestType,
1313 k, isRDNA3);
1314
1315 // Handle gfx1250.
1316 if (chipset == kGfx1250)
1317 return wmmaOpToIntrinsicGfx1250(elemSourceType, elemBSourceType,
1318 elemDestType, k);
1319
1320 return std::nullopt;
1321}
1322
1323namespace {
1324struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
1325 MFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1326 : ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {}
1327
1328 Chipset chipset;
1329
1330 LogicalResult
1331 matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
1332 ConversionPatternRewriter &rewriter) const override {
1333 Location loc = op.getLoc();
1334 Type outType = typeConverter->convertType(op.getDestD().getType());
1335 Type intrinsicOutType = outType;
1336 if (auto outVecType = dyn_cast<VectorType>(outType))
1337 if (outVecType.getElementType().isBF16())
1338 intrinsicOutType = outVecType.clone(rewriter.getI16Type());
1339
1340 if (chipset.majorVersion != 9 || chipset < kGfx908)
1341 return op->emitOpError("MFMA only supported on gfx908+");
1342 uint32_t getBlgpField = static_cast<uint32_t>(op.getBlgp());
1343 if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
1344 if (chipset < kGfx942)
1345 return op.emitOpError("negation unsupported on older than gfx942");
1346 getBlgpField |=
1347 op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
1348 }
1349 std::optional<StringRef> maybeIntrinsic = mfmaOpToIntrinsic(op, chipset);
1350 std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1351 maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
1352 if (!maybeIntrinsic.has_value() && !maybeScaledIntrinsic.has_value())
1353 return op.emitOpError("no intrinsic matching MFMA size on given chipset");
1354
1355 bool isScaled =
1356 !maybeIntrinsic.has_value() && maybeScaledIntrinsic.has_value();
1357 if (isScaled &&
1358 (adaptor.getAbid() > 0 || getBlgpField > 0 || op.getCbsz() > 0)) {
1359 return op.emitOpError(
1360 "non-default abid, blgp, and cbsz aren't supported on MFMAs that can "
1361 "be scaled as those fields are used for type information");
1362 }
1363
1364 StringRef intrinsicName =
1365 isScaled ? std::get<0>(*maybeScaledIntrinsic) : *maybeIntrinsic;
1366 // Determine if we can use bf16 in the intrinsic. Newer MFMAs in gfx950+
1367 // allow bf16 as the input. For reference, check the IntrinsicsAMDGPU.td file.
1368 bool allowBf16 = [&]() {
1369 if (chipset < kGfx950)
1370 return false;
1371 if (isScaled)
1372 return true;
1373 return intrinsicName.contains("16x16x32.bf16") ||
1374 intrinsicName.contains("32x32x16.bf16");
1375 }();
1376 OperationState loweredOp(loc, intrinsicName);
1377 loweredOp.addTypes(intrinsicOutType);
1378 loweredOp.addOperands({packSmallFloatVectorOperand(
1379 rewriter, loc, adaptor.getSourceA(), allowBf16),
1380 packSmallFloatVectorOperand(
1381 rewriter, loc, adaptor.getSourceB(), allowBf16),
1382 adaptor.getDestC()});
1383 if (isScaled) {
1384 Value zero = createI32Constant(rewriter, loc, 0);
1385 auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
1386 loweredOp.addOperands({createI32Constant(rewriter, loc, aTypeCode),
1387 createI32Constant(rewriter, loc, bTypeCode),
1388 /*scale A byte=*/zero, /*scale A=*/zero,
1389 /*scale B byte=*/zero, /*scale B=*/zero});
1390 } else {
1391 loweredOp.addOperands({createI32Constant(rewriter, loc, op.getCbsz()),
1392 createI32Constant(rewriter, loc, op.getAbid()),
1393 createI32Constant(rewriter, loc, getBlgpField)});
1394 }
1395 Value lowered = rewriter.create(loweredOp)->getResult(0);
1396 if (outType != intrinsicOutType)
1397 lowered = LLVM::BitcastOp::create(rewriter, loc, outType, lowered);
1398 rewriter.replaceOp(op, lowered);
1399 return success();
1400 }
1401};
1402
1403struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
1404 ScaledMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1405 : ConvertOpToLLVMPattern(converter), chipset(chipset) {}
1406
1407 Chipset chipset;
1408
1409 LogicalResult
1410 matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor,
1411 ConversionPatternRewriter &rewriter) const override {
1412 Location loc = op.getLoc();
1413 Type intrinsicOutType = typeConverter->convertType(op.getDestD().getType());
1414
1415 if (chipset.majorVersion != 9 || chipset < kGfx950)
1416 return op->emitOpError("scaled MFMA only supported on gfx950+");
1417 std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1418 maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
1419 if (!maybeScaledIntrinsic.has_value())
1420 return op.emitOpError(
1421 "no intrinsic matching scaled MFMA size on given chipset");
1422
1423 auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
1424 OperationState loweredOp(loc, intrinsicName);
1425 loweredOp.addTypes(intrinsicOutType);
1426 loweredOp.addOperands(
1427 {packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceA()),
1428 packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceB()),
1429 adaptor.getDestC()});
1430 Value scalesIdxA =
1431 createI32Constant(rewriter, loc, adaptor.getScalesIdxA());
1432 Value scalesIdxB =
1433 createI32Constant(rewriter, loc, adaptor.getScalesIdxB());
1434 loweredOp.addOperands(
1435 {createI32Constant(rewriter, loc, aTypeCode),
1436 createI32Constant(rewriter, loc, bTypeCode),
1437 /*scales idx A=*/scalesIdxA,
1438 /*scales A*/
1439 castScaleOperand(rewriter, loc, adaptor.getScalesA()),
1440 /*scales idx B=*/scalesIdxB,
1441 /*scales B*/
1442 castScaleOperand(rewriter, loc, adaptor.getScalesB())});
1443 Value lowered = rewriter.create(loweredOp)->getResult(0);
1444 rewriter.replaceOp(op, lowered);
1445 return success();
1446 }
1447};
1448
1449struct SparseMFMAOpLowering : public ConvertOpToLLVMPattern<SparseMFMAOp> {
1450 SparseMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1451 : ConvertOpToLLVMPattern<SparseMFMAOp>(converter), chipset(chipset) {}
1452
1453 Chipset chipset;
1454
1455 LogicalResult
1456 matchAndRewrite(SparseMFMAOp op, SparseMFMAOpAdaptor adaptor,
1457 ConversionPatternRewriter &rewriter) const override {
1458 Location loc = op.getLoc();
1459 auto outType =
1460 typeConverter->convertType<VectorType>(op.getDestC().getType());
1461 if (!outType)
1462 return rewriter.notifyMatchFailure(op, "type conversion failed");
1463
1464 // smfmac is supported on gfx942 and gfx950.
1465 if (chipset.majorVersion != 9 || chipset < kGfx942)
1466 return op->emitOpError("sparse MFMA (smfmac) only supported on gfx942+");
1467 bool isGfx950 = chipset >= kGfx950;
1468
1469 Value a = convertSparseMFMAVectorOperand(rewriter, loc,
1470 adaptor.getSourceA(), isGfx950);
1471 Value b = convertSparseMFMAVectorOperand(rewriter, loc,
1472 adaptor.getSourceB(), isGfx950);
1473 Value c = adaptor.getDestC();
1474
1475 std::optional<StringRef> maybeIntrinsic = smfmacOpToIntrinsic(op, chipset);
1476 if (!maybeIntrinsic.has_value())
1477 return op.emitOpError(
1478 "no intrinsic matching sparse MFMA on the given chipset");
1479
1480 // Bitcast sparse indices from vector<4xi8> or vector<2xi16> to i32.
1481 Value sparseIdx = LLVM::BitcastOp::create(
1482 rewriter, loc, rewriter.getI32Type(), adaptor.getSparseIdx());
1483
1484 OperationState loweredOp(loc, maybeIntrinsic.value());
1485 loweredOp.addTypes(outType);
1486 loweredOp.addOperands({a, b, c, sparseIdx,
1487 createI32Constant(rewriter, loc, op.getCbsz()),
1488 createI32Constant(rewriter, loc, op.getAbid())});
1489 Value lowered = rewriter.create(loweredOp)->getResult(0);
1490 rewriter.replaceOp(op, lowered);
1491 return success();
1492 }
1493};
1494
1495struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
1496 WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1497 : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}
1498
1499 Chipset chipset;
1500
1501 LogicalResult
1502 matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
1503 ConversionPatternRewriter &rewriter) const override {
1504 Location loc = op.getLoc();
1505 auto outType =
1506 typeConverter->convertType<VectorType>(op.getDestD().getType());
1507 if (!outType)
1508 return rewriter.notifyMatchFailure(op, "type conversion failed");
1509
1510 if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
1511 return op->emitOpError("WMMA only supported on gfx11 and gfx12");
1512
1513 bool isGFX1250 = chipset >= kGfx1250;
1514
1515 // The WMMA operations represent vectors of bf16s as vectors of i16s
1516 // (except on gfx1250), so we need to bitcast bfloats to i16 and then
1517 // bitcast them back.
1518 auto aType = cast<VectorType>(adaptor.getSourceA().getType());
1519 auto bType = cast<VectorType>(adaptor.getSourceB().getType());
1520 auto destCType = cast<VectorType>(adaptor.getDestC().getType());
1521 bool castAToI16 = aType.getElementType().isBF16() && !isGFX1250;
1522 bool castBToI16 = bType.getElementType().isBF16() && !isGFX1250;
1523 bool castDestCToI16 = destCType.getElementType().isBF16() && !isGFX1250;
1524 bool castOutToI16 = outType.getElementType().isBF16() && !isGFX1250;
1525 VectorType rawOutType = outType;
1526 if (castOutToI16)
1527 rawOutType = outType.clone(rewriter.getI16Type());
1528 Value a = adaptor.getSourceA();
1529 if (castAToI16)
1530 a = LLVM::BitcastOp::create(rewriter, loc,
1531 aType.clone(rewriter.getI16Type()), a);
1532 Value b = adaptor.getSourceB();
1533 if (castBToI16)
1534 b = LLVM::BitcastOp::create(rewriter, loc,
1535 bType.clone(rewriter.getI16Type()), b);
1536 Value destC = adaptor.getDestC();
1537 if (castDestCToI16)
1538 destC = LLVM::BitcastOp::create(
1539 rewriter, loc, destCType.clone(rewriter.getI16Type()), destC);
1540
1541 std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);
1542
1543 if (!maybeIntrinsic.has_value())
1544 return op.emitOpError("no intrinsic matching WMMA on the given chipset");
1545
1546 if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0)
1547 return op.emitOpError("subwordOffset not supported on gfx12+");
1548
1549 SmallVector<Value, 4> operands;
1550 SmallVector<NamedAttribute, 4> attrs;
1551 wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(), a,
1552 op.getSourceA(), operands, attrs, "signA");
1553 wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(), b,
1554 op.getSourceB(), operands, attrs, "signB");
1555 wmmaPushOutputOperand(rewriter, loc, typeConverter, destC,
1556 op.getSubwordOffset(), op.getClamp(), operands,
1557 attrs);
1558
1559 OperationState loweredOp(loc, *maybeIntrinsic);
1560 loweredOp.addTypes(rawOutType);
1561 loweredOp.addOperands(operands);
1562 loweredOp.addAttributes(attrs);
1563 Operation *lowered = rewriter.create(loweredOp);
1564
1565 Operation *maybeCastBack = lowered;
1566 if (rawOutType != outType)
1567 maybeCastBack = LLVM::BitcastOp::create(rewriter, loc, outType,
1568 lowered->getResult(0));
1569 rewriter.replaceOp(op, maybeCastBack->getResults());
1570
1571 return success();
1572 }
1573};
1574
1575struct ScaledWMMAOpLowering : public ConvertOpToLLVMPattern<ScaledWMMAOp> {
1576 ScaledWMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1577 : ConvertOpToLLVMPattern<ScaledWMMAOp>(converter), chipset(chipset) {}
1578
1579 Chipset chipset;
1580
1581 LogicalResult
1582 matchAndRewrite(ScaledWMMAOp op, ScaledWMMAOpAdaptor adaptor,
1583 ConversionPatternRewriter &rewriter) const override {
1584 Location loc = op.getLoc();
1585 auto outType =
1586 typeConverter->convertType<VectorType>(op.getDestD().getType());
1587 if (!outType)
1588 return rewriter.notifyMatchFailure(op, "type conversion failed");
1589
1590 if (chipset < kGfx1250)
1591 return op->emitOpError("WMMA scale only supported on gfx1250+");
1592
1593 int64_t m = op.getM();
1594 int64_t n = op.getN();
1595 int64_t k = op.getK();
1596
1597 Type aElemType = getElementTypeOrSelf(op.getSourceA().getType());
1598 Type bElemType = getElementTypeOrSelf(op.getSourceB().getType());
1599
1600 std::optional<uint32_t> aFmtCode = smallFloatTypeToFormatCode(aElemType);
1601 std::optional<uint32_t> bFmtCode = smallFloatTypeToFormatCode(bElemType);
1602
1603 if (!aFmtCode || !bFmtCode)
1604 return op.emitOpError("unsupported element types for scaled_wmma");
1605
1606 // Get scale vector types and determine variant (scale vs scale16).
1607 auto scaleAVecType = cast<VectorType>(op.getScaleA().getType());
1608 auto scaleBVecType = cast<VectorType>(op.getScaleB().getType());
1609
1610 if (scaleAVecType.getNumElements() != scaleBVecType.getNumElements())
1611 return op.emitOpError("scaleA and scaleB must have equal vector length");
1612
1613 // Extract scale format from element types.
1614 Type scaleAElemType = scaleAVecType.getElementType();
1615 Type scaleBElemType = scaleBVecType.getElementType();
1616
1617 std::optional<uint32_t> scaleAFmt = getWmmaScaleFormat(scaleAElemType);
1618 std::optional<uint32_t> scaleBFmt = getWmmaScaleFormat(scaleBElemType);
1619
1620 if (!scaleAFmt || !scaleBFmt)
1621 return op.emitOpError("unsupported scale element types");
1622
1623 // Determine which intrinsic to use based on dimensions.
1624 bool isScale16 = (scaleAVecType.getNumElements() == 8);
1625 std::optional<StringRef> intrinsicName =
1626 getScaledWmmaIntrinsicName(m, n, k, isScale16);
1627 if (!intrinsicName)
1628 return op.emitOpError("unsupported scaled_wmma dimensions: ")
1629 << m << "x" << n << "x" << k;
1630
1631 SmallVector<NamedAttribute, 8> attrs;
1632
1633 // The f4 variant does not have fmtA and fmtB attributes.
1634 bool is32x16 = (m == 32 && n == 16 && k == 128);
1635 if (!is32x16) {
1636 attrs.emplace_back("fmtA", rewriter.getI32IntegerAttr(*aFmtCode));
1637 attrs.emplace_back("fmtB", rewriter.getI32IntegerAttr(*bFmtCode));
1638 }
1639
1640 // modC uses the default value of 0.
1641 attrs.emplace_back("modC", rewriter.getI16IntegerAttr(0));
1642
1643 // Scale attributes. Convert user-facing firstScaleLane (0 or 16) to the
1644 // half of the wave that is being selected (0 or 1).
1645 attrs.emplace_back(
1646 "scaleAType", rewriter.getI32IntegerAttr(op.getAFirstScaleLane() / 16));
1647 attrs.emplace_back("fmtScaleA", rewriter.getI32IntegerAttr(*scaleAFmt));
1648 attrs.emplace_back(
1649 "scaleBType", rewriter.getI32IntegerAttr(op.getBFirstScaleLane() / 16));
1650 attrs.emplace_back("fmtScaleB", rewriter.getI32IntegerAttr(*scaleBFmt));
1651
1652 // The reuse flags use the default value of false.
1653 attrs.emplace_back("reuseA", rewriter.getBoolAttr(false));
1654 attrs.emplace_back("reuseB", rewriter.getBoolAttr(false));
1655
1656 // Convert typed float vectors to packed format.
1657 Value sourceA =
1658 packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceA());
1659 Value sourceB =
1660 packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceB());
1661
1662 // Pack scale vectors into i32/i64.
1663 Value packedScaleA = castScaleOperand(rewriter, loc, adaptor.getScaleA());
1664 Value packedScaleB = castScaleOperand(rewriter, loc, adaptor.getScaleB());
1665
1666 // Create the intrinsic call.
1667 OperationState loweredOp(loc, *intrinsicName);
1668 loweredOp.addTypes(outType);
1669 loweredOp.addOperands(
1670 {sourceA, sourceB, adaptor.getDestC(), packedScaleA, packedScaleB});
1671 loweredOp.addAttributes(attrs);
1672
1673 Operation *lowered = rewriter.create(loweredOp);
1674 rewriter.replaceOp(op, lowered->getResults());
1675
1676 return success();
1677 }
1678};
1679
1680struct TransposeLoadOpLowering
1681 : public ConvertOpToLLVMPattern<TransposeLoadOp> {
1682 TransposeLoadOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1683 : ConvertOpToLLVMPattern<TransposeLoadOp>(converter), chipset(chipset) {}
1684
1685 Chipset chipset;
1686
1687 LogicalResult
1688 matchAndRewrite(TransposeLoadOp op, TransposeLoadOpAdaptor adaptor,
1689 ConversionPatternRewriter &rewriter) const override {
1690 if (chipset != kGfx950)
1691 return op.emitOpError("Non-gfx950 chipset not supported");
1692
1693 Location loc = op.getLoc();
1694 auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
1695
1696 // Elements in sub-byte memrefs are not stored contiguously, so reject
1697 // sub-byte source memrefs; use emulated memrefs instead.
1698 size_t srcElementSize =
1699 srcMemRefType.getElementType().getIntOrFloatBitWidth();
1700 if (srcElementSize < 8)
1701 return op.emitOpError("Expect source memref to have at least 8 bits "
1702 "element size, got ")
1703 << srcElementSize;
1704
1705 auto resultType = cast<VectorType>(op.getResult().getType());
1706 Value srcPtr =
1707 getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
1708 (adaptor.getSrcIndices()));
1709
1710 size_t numElements = resultType.getNumElements();
1711 size_t elementTypeSize =
1712 resultType.getElementType().getIntOrFloatBitWidth();
1713
1714 // ROCDL transpose load intrinsics return vectors of 32-bit integers when
1715 // the element size is smaller than 16 bits.
1716 Type rocdlResultType = VectorType::get((numElements * elementTypeSize) / 32,
1717 rewriter.getIntegerType(32));
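// For example, a vector<16xi4> result (16 * 4 = 64 bits) maps to a
// vector<2xi32> intrinsic result type.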
1718 Type llvmResultType = typeConverter->convertType(resultType);
1719
1720 switch (elementTypeSize) {
1721 case 4: {
1722 assert(numElements == 16);
1723 auto rocdlOp = ROCDL::ds_read_tr4_b64::create(rewriter, loc,
1724 rocdlResultType, srcPtr);
1725 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
1726 break;
1727 }
1728 case 6: {
1729 assert(numElements == 16);
1730 auto rocdlOp = ROCDL::ds_read_tr6_b96::create(rewriter, loc,
1731 rocdlResultType, srcPtr);
1732 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
1733 break;
1734 }
1735 case 8: {
1736 assert(numElements == 8);
1737 auto rocdlOp = ROCDL::ds_read_tr8_b64::create(rewriter, loc,
1738 rocdlResultType, srcPtr);
1739 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
1740 break;
1741 }
1742 case 16: {
1743 assert(numElements == 4);
1744 rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr16_b64>(op, llvmResultType,
1745 srcPtr);
1746 break;
1747 }
1748 default:
1749 return op.emitOpError("Unsupported element size for transpose load");
1750 }
1751 return success();
1752 }
1753};
1754
1755struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
1756 GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1757 : ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}
1758
1759 Chipset chipset;
1760
1761 LogicalResult
1762 matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor,
1763 ConversionPatternRewriter &rewriter) const override {
1764 if (chipset.majorVersion < 9 || chipset.majorVersion > 10)
1765 return op.emitOpError("pre-gfx9 and post-gfx10 not supported");
1766
1767 Location loc = op.getLoc();
1768
1769 auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
1770 auto dstMemRefType = cast<MemRefType>(op.getDst().getType());
1771
1772 // TODO: instead of only transferring one element per thread, we could
1773 // transfer multiple elements per thread by issuing multiple
1774 // `global_load_lds` instructions.
1775 Type transferType = op.getTransferType();
1776 int loadWidth = [&]() -> int {
1777 if (auto transferVectorType = dyn_cast<VectorType>(transferType)) {
1778 return (transferVectorType.getNumElements() *
1779 transferVectorType.getElementTypeBitWidth()) /
1780 8;
1781 }
1782 return transferType.getIntOrFloatBitWidth() / 8;
1783 }();
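// For example, a vector<4xi8> transfer type yields (4 * 8) / 8 = 4 bytes,
// while a scalar f16 yields 2 bytes.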
1784
1785 // Currently, only 1-, 2-, 4-, 12- and 16-byte loads are supported.
1786 if (!llvm::is_contained({1, 2, 4, 12, 16}, loadWidth))
1787 return op.emitOpError("chipset unsupported element size");
1788
1789 if (chipset != kGfx950 && llvm::is_contained({12, 16}, loadWidth))
1790 return op.emitOpError("Gather to LDS instructions with 12-byte and "
1791 "16-byte load widths are only supported on gfx950");
1792
1793 Value srcPtr =
1794 getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
1795 (adaptor.getSrcIndices()));
1796 Value dstPtr =
1797 getStridedElementPtr(rewriter, loc, dstMemRefType, adaptor.getDst(),
1798 (adaptor.getDstIndices()));
1799
1800 rewriter.replaceOpWithNewOp<ROCDL::LoadToLDSOp>(
1801 op, srcPtr, dstPtr, rewriter.getI32IntegerAttr(loadWidth),
1802 /*offset=*/rewriter.getI32IntegerAttr(0),
1803 /*aux=*/rewriter.getI32IntegerAttr(0), ArrayAttr{}, ArrayAttr{},
1804 ArrayAttr{});
1805
1806 return success();
1807 }
1808};
1809
1810namespace {
1811struct ExtPackedFp8OpLowering final
1812 : public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
1813 ExtPackedFp8OpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1814 : ConvertOpToLLVMPattern<amdgpu::ExtPackedFp8Op>(converter),
1815 chipset(chipset) {}
1816 Chipset chipset;
1817
1818 LogicalResult
1819 matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
1820 ConversionPatternRewriter &rewriter) const override;
1821};
1822
1823struct ScaledExtPackedMatrixOpLowering final
1824 : public ConvertOpToLLVMPattern<ScaledExtPackedMatrixOp> {
1825 ScaledExtPackedMatrixOpLowering(const LLVMTypeConverter &converter,
1826 Chipset chipset)
1827 : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedMatrixOp>(converter),
1828 chipset(chipset) {}
1829 Chipset chipset;
1830
1831 LogicalResult
1832 matchAndRewrite(ScaledExtPackedMatrixOp op,
1833 ScaledExtPackedMatrixOpAdaptor adaptor,
1834 ConversionPatternRewriter &rewriter) const override;
1835};
1836
1837struct PackedTrunc2xFp8OpLowering final
1838 : public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> {
1839 PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter,
1840 Chipset chipset)
1841 : ConvertOpToLLVMPattern<amdgpu::PackedTrunc2xFp8Op>(converter),
1842 chipset(chipset) {}
1843 Chipset chipset;
1844
1845 LogicalResult
1846 matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
1847 ConversionPatternRewriter &rewriter) const override;
1848};
1849
1850struct PackedStochRoundFp8OpLowering final
1851 : public ConvertOpToLLVMPattern<PackedStochRoundFp8Op> {
1852 PackedStochRoundFp8OpLowering(const LLVMTypeConverter &converter,
1853 Chipset chipset)
1854 : ConvertOpToLLVMPattern<amdgpu::PackedStochRoundFp8Op>(converter),
1855 chipset(chipset) {}
1856 Chipset chipset;
1857
1858 LogicalResult
1859 matchAndRewrite(PackedStochRoundFp8Op op,
1860 PackedStochRoundFp8OpAdaptor adaptor,
1861 ConversionPatternRewriter &rewriter) const override;
1862};
1863
1864struct ScaledExtPackedOpLowering final
1865 : public ConvertOpToLLVMPattern<ScaledExtPackedOp> {
1866 ScaledExtPackedOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1867 : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedOp>(converter),
1868 chipset(chipset) {}
1869 Chipset chipset;
1870
1871 LogicalResult
1872 matchAndRewrite(ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
1873 ConversionPatternRewriter &rewriter) const override;
1874};
1875
1876struct PackedScaledTruncOpLowering final
1877 : public ConvertOpToLLVMPattern<PackedScaledTruncOp> {
1878 PackedScaledTruncOpLowering(const LLVMTypeConverter &converter,
1879 Chipset chipset)
1880 : ConvertOpToLLVMPattern<amdgpu::PackedScaledTruncOp>(converter),
1881 chipset(chipset) {}
1882 Chipset chipset;
1883
1884 LogicalResult
1885 matchAndRewrite(PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
1886 ConversionPatternRewriter &rewriter) const override;
1887};
1888
1889} // end namespace
1890
1891LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
1892 ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
1893 ConversionPatternRewriter &rewriter) const {
1894 Location loc = op.getLoc();
1895 if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
1896 return rewriter.notifyMatchFailure(
1897 loc, "Fp8 conversion instructions are not available on target "
1898 "architecture and their emulation is not implemented");
1899 Type v4i8 =
1900 getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type()));
1901 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
1902 Type f32 = getTypeConverter()->convertType(op.getResult().getType());
1903
1904 Value source = adaptor.getSource();
1905 auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
1906 auto resultVecType = dyn_cast<VectorType>(op.getResult().getType());
1907 Type sourceElemType = getElementTypeOrSelf(op.getSource());
1908 // Extend to a v4i8
1909 if (!sourceVecType || sourceVecType.getNumElements() < 4) {
1910 Value longVec = LLVM::UndefOp::create(rewriter, loc, v4i8);
1911 if (!sourceVecType) {
1912 longVec = LLVM::InsertElementOp::create(
1913 rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0));
1914 } else {
1915 for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
1916 Value idx = createI32Constant(rewriter, loc, i);
1917 Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
1918 longVec =
1919 LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
1920 }
1921 }
1922 source = longVec;
1923 }
1924 Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
1925 if (resultVecType) {
1926 if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
1927 rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source,
1928 op.getIndex());
1929 } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
1930 rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Fp8Op>(op, f32, i32Source,
1931 op.getIndex());
1932 }
1933 } else {
1934 if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
1935 rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
1936 op.getIndex());
1937 } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
1938 rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
1939 op.getIndex());
1940 }
1941 }
1942 return success();
1943}
1944
1945int32_t getScaleSel(int32_t blockSize, unsigned bitWidth, int32_t scaleWaveHalf,
1946 int32_t firstScaleByte) {
1947 // When lowering amdgpu.scaled_ext_packed_matrix to rocdl.cvt.scale.pk*.f*.f*
1948 // operations, the attributes blockSize, sourceType, scaleWaveHalf, and
1949 // firstScaleByte are merged into a single attribute scaleSel. This is how
1950 // those values are merged together. (Note: scaleWaveHalf isn't a high-level
1951 // attribute but is derived from firstScaleLane).
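// For example, for a non-fp8 source (bitWidth 4 or 6) with blockSize = 16,
// firstScaleByte = 2, and scaleWaveHalf = 1, the result is
// (1 << 2) | (1 << 1) | 1 = 0b111. The fp8 path instead gives firstScaleByte
// two bits and places scaleWaveHalf at bit 3.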
1952 assert(llvm::is_contained({16, 32}, blockSize));
1953 assert(llvm::is_contained({4u, 6u, 8u}, bitWidth));
1954
1955 const bool isFp8 = bitWidth == 8;
1956 const bool isBlock16 = blockSize == 16;
1957
1958 if (!isFp8) {
1959 int32_t bit0 = isBlock16;
1960 assert(llvm::is_contained({0, 1, 2}, firstScaleByte));
1961 int32_t bit1 = (firstScaleByte == 2) << 1;
1962 assert(llvm::is_contained({0, 1}, scaleWaveHalf));
1963 int32_t bit2 = scaleWaveHalf << 2;
1964 return bit2 | bit1 | bit0;
1965 }
1966
1967 int32_t bit0 = isBlock16;
1968 // firstScaleByte is guaranteed to fit in two bits.
1969 assert(llvm::is_contained({0, 1, 2, 3}, firstScaleByte));
1970 int32_t bits2and1 = firstScaleByte << 1;
1971 assert(llvm::is_contained({0, 1}, scaleWaveHalf));
1972 int32_t bit3 = scaleWaveHalf << 3;
1973 int32_t bits = bit3 | bits2and1 | bit0;
1974 // These are invalid cases.
1975 assert(!llvm::is_contained(
1976 {0b0011, 0b0101, 0b0111, 0b1000, 0b1001, 0b1011, 0b1111}, bits));
1977 return bits;
1978}
1979
1980static std::optional<StringRef>
1981scaledExtPacked816ToIntrinsic(Type srcElemType, Type destElemType) {
1982 using fp4 = Float4E2M1FNType;
1983 using fp8 = Float8E4M3FNType;
1984 using bf8 = Float8E5M2Type;
1985 using fp6 = Float6E2M3FNType;
1986 using bf6 = Float6E3M2FNType;
1987 if (isa<fp4>(srcElemType)) {
1988 if (destElemType.isF16())
1989 return ROCDL::CvtPkScalePk8F16Fp4Op::getOperationName();
1990 if (destElemType.isBF16())
1991 return ROCDL::CvtPkScalePk8Bf16Fp4Op::getOperationName();
1992 if (destElemType.isF32())
1993 return ROCDL::CvtPkScalePk8F32Fp4Op::getOperationName();
1994 return std::nullopt;
1995 }
1996 if (isa<fp8>(srcElemType)) {
1997 if (destElemType.isF16())
1998 return ROCDL::CvtPkScalePk8F16Fp8Op::getOperationName();
1999 if (destElemType.isBF16())
2000 return ROCDL::CvtPkScalePk8Bf16Fp8Op::getOperationName();
2001 if (destElemType.isF32())
2002 return ROCDL::CvtPkScalePk8F32Fp8Op::getOperationName();
2003 return std::nullopt;
2004 }
2005 if (isa<bf8>(srcElemType)) {
2006 if (destElemType.isF16())
2007 return ROCDL::CvtPkScalePk8F16Bf8Op::getOperationName();
2008 if (destElemType.isBF16())
2009 return ROCDL::CvtPkScalePk8Bf16Bf8Op::getOperationName();
2010 if (destElemType.isF32())
2011 return ROCDL::CvtPkScalePk8F32Bf8Op::getOperationName();
2012 return std::nullopt;
2013 }
2014 if (isa<fp6>(srcElemType)) {
2015 if (destElemType.isF16())
2016 return ROCDL::CvtPkScalePk16F16Fp6Op::getOperationName();
2017 if (destElemType.isBF16())
2018 return ROCDL::CvtPkScalePk16Bf16Fp6Op::getOperationName();
2019 if (destElemType.isF32())
2020 return ROCDL::CvtPkScalePk16F32Fp6Op::getOperationName();
2021 return std::nullopt;
2022 }
2023 if (isa<bf6>(srcElemType)) {
2024 if (destElemType.isF16())
2025 return ROCDL::CvtPkScalePk16F16Bf6Op::getOperationName();
2026 if (destElemType.isBF16())
2027 return ROCDL::CvtPkScalePk16Bf16Bf6Op::getOperationName();
2028 if (destElemType.isF32())
2029 return ROCDL::CvtPkScalePk16F32Bf6Op::getOperationName();
2030 return std::nullopt;
2031 }
2032 llvm_unreachable("invalid combination of element types for packed conversion "
2033 "instructions");
2034}
2035
2036LogicalResult ScaledExtPackedMatrixOpLowering::matchAndRewrite(
2037 ScaledExtPackedMatrixOp op, ScaledExtPackedMatrixOpAdaptor adaptor,
2038 ConversionPatternRewriter &rewriter) const {
2039 using fp4 = Float4E2M1FNType;
2040 using fp8 = Float8E4M3FNType;
2041 using bf8 = Float8E5M2Type;
2042 using fp6 = Float6E2M3FNType;
2043 using bf6 = Float6E3M2FNType;
2044 Location loc = op.getLoc();
2045 if (chipset != kGfx1250) {
2046 return rewriter.notifyMatchFailure(
2047 loc,
2048 "Scaled fp packed conversion instructions are not available on target "
2049 "architecture and their emulation is not implemented");
2050 }
2051 // Convert user-facing firstScaleLane (0 or 16) to the half of the wave that
2052 // is being selected.
2053 int32_t scaleWaveHalf = op.getFirstScaleLane() / 16;
2054 int32_t firstScaleByte = op.getFirstScaleByte();
2055 int32_t blockSize = op.getBlockSize();
2056 auto sourceType = cast<VectorType>(op.getSource().getType());
2057 auto srcElemType = cast<FloatType>(sourceType.getElementType());
2058 unsigned bitWidth = srcElemType.getWidth();
2059
2060 auto targetType = cast<VectorType>(op.getResult().getType());
2061 auto destElemType = cast<FloatType>(targetType.getElementType());
2062
2063 IntegerType i32 = rewriter.getI32Type();
2064 Value source = adaptor.getSource();
2065 Type llvmResultType = typeConverter->convertType(op.getResult().getType());
2066 Type packedType = nullptr;
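// The source vector is reinterpreted as packed 32-bit words: 8 fp4 values
// fit in one i32, 8 fp8/bf8 values in a vector<2xi32>, and 16 fp6/bf6 values
// in a vector<3xi32>.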
2067 if (isa<fp4>(srcElemType)) {
2068 packedType = i32;
2069 packedType = getTypeConverter()->convertType(packedType);
2070 } else if (isa<fp8, bf8>(srcElemType)) {
2071 packedType = VectorType::get(2, i32);
2072 packedType = getTypeConverter()->convertType(packedType);
2073 } else if (isa<fp6, bf6>(srcElemType)) {
2074 packedType = VectorType::get(3, i32);
2075 packedType = getTypeConverter()->convertType(packedType);
2076 } else {
2077 llvm_unreachable("invalid element type for packed scaled ext");
2078 }
2079
2080 if (!packedType || !llvmResultType) {
2081 return rewriter.notifyMatchFailure(op, "type conversion failed");
2082 }
2083
2084 std::optional<StringRef> maybeIntrinsic =
2085 scaledExtPacked816ToIntrinsic(srcElemType, destElemType);
2086 if (!maybeIntrinsic.has_value())
2087 return op.emitOpError(
2088 "no intrinsic matching packed scaled conversion on the given chipset");
2089
2090 int32_t scaleSel =
2091 getScaleSel(blockSize, bitWidth, scaleWaveHalf, firstScaleByte);
2092 Value castedScale =
2093 LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getScale());
2094 Value castedSource =
2095 LLVM::BitcastOp::create(rewriter, loc, packedType, source);
2096
2097 OperationState loweredOp(loc, *maybeIntrinsic);
2098 loweredOp.addTypes({llvmResultType});
2099 loweredOp.addOperands({castedSource, castedScale});
2100
2101 SmallVector<NamedAttribute, 1> attrs;
2102 attrs.push_back(
2103 NamedAttribute("scaleSel", rewriter.getI32IntegerAttr(scaleSel)));
2104
2105 loweredOp.addAttributes(attrs);
2106 Operation *lowered = rewriter.create(loweredOp);
2107 rewriter.replaceOp(op, lowered);
2108
2109 return success();
2110}
2111
2112LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
2113 ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
2114 ConversionPatternRewriter &rewriter) const {
2115 Location loc = op.getLoc();
2116 if (chipset != kGfx950)
2117 return rewriter.notifyMatchFailure(
2118 loc, "Scaled fp conversion instructions are not available on target "
2119 "architecture and their emulation is not implemented");
2120 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2121
2122 Value source = adaptor.getSource();
2123 Value scale = adaptor.getScale();
2124
2125 VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
2126 Type sourceElemType = sourceVecType.getElementType();
2127 VectorType destVecType = cast<VectorType>(op.getResult().getType());
2128 Type destElemType = destVecType.getElementType();
2129
2130 VectorType packedVecType;
2131 if (isa<Float8E5M2Type, Float8E4M3FNType>(sourceElemType)) {
2132 VectorType v4i8 = VectorType::get(4, rewriter.getI8Type());
2133 packedVecType = cast<VectorType>(getTypeConverter()->convertType(v4i8));
2134 } else if (isa<Float4E2M1FNType>(sourceElemType)) {
2135 VectorType v8i4 = VectorType::get(8, rewriter.getI4Type());
2136 packedVecType = cast<VectorType>(getTypeConverter()->convertType(v8i4));
2137 } else {
2138 llvm_unreachable("invalid element type for scaled ext");
2139 }
2140
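// Either packed layout occupies 32 bits (4 x i8 or 8 x i4), so the source
// can be widened if needed and then bitcast to a single i32 below.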
2141 // Extend to a packedVectorType
2142 if (sourceVecType.getNumElements() < packedVecType.getNumElements()) {
2143 Value longVec = LLVM::ZeroOp::create(rewriter, loc, packedVecType);
2144 if (!sourceVecType) {
2145 longVec = LLVM::InsertElementOp::create(
2146 rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0));
2147 } else {
2148 for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
2149 Value idx = createI32Constant(rewriter, loc, i);
2150 Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
2151 longVec =
2152 LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
2153 }
2154 }
2155 source = longVec;
2156 }
2157 Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
2158
2159 if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF32())
2160 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Bf8Op>(
2161 op, destVecType, i32Source, scale, op.getIndex());
2162 else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF16())
2163 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Bf8Op>(
2164 op, destVecType, i32Source, scale, op.getIndex());
2165 else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isBF16())
2166 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Bf8Op>(
2167 op, destVecType, i32Source, scale, op.getIndex());
2168 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF32())
2169 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp8Op>(
2170 op, destVecType, i32Source, scale, op.getIndex());
2171 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF16())
2172 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp8Op>(
2173 op, destVecType, i32Source, scale, op.getIndex());
2174 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isBF16())
2175 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp8Op>(
2176 op, destVecType, i32Source, scale, op.getIndex());
2177 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF32())
2178 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp4Op>(
2179 op, destVecType, i32Source, scale, op.getIndex());
2180 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF16())
2181 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp4Op>(
2182 op, destVecType, i32Source, scale, op.getIndex());
2183 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isBF16())
2184 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp4Op>(
2185 op, destVecType, i32Source, scale, op.getIndex());
2186 else
2187 return failure();
2188
2189 return success();
2190}
2191
2192LogicalResult PackedScaledTruncOpLowering::matchAndRewrite(
2193 PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
2194 ConversionPatternRewriter &rewriter) const {
2195 Location loc = op.getLoc();
2196 if (chipset != kGfx950)
2197 return rewriter.notifyMatchFailure(
2198 loc, "Scaled fp conversion instructions are not available on target "
2199 "architecture and their emulation is not implemented");
2200 Type v2i16 = getTypeConverter()->convertType(
2201 VectorType::get(2, rewriter.getI16Type()));
2202 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2203
2204 Type resultType = op.getResult().getType();
2205 Type resultElemType = getElementTypeOrSelf(resultType);
2206 VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
2207 Type sourceElemType = sourceVecType.getElementType();
2208
2209 Type intResultType = isa<Float4E2M1FNType>(resultElemType) ? i32 : v2i16;
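// fp4 results are packed into a single i32, while fp8/bf8 results are packed
// into a vector<2xi16>; either view is bitcast back to the op's result type
// at the end.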
2210
2211 Value source = adaptor.getSource();
2212 Value scale = adaptor.getScale();
2213 Value existing = adaptor.getExisting();
2214 if (existing)
2215 existing = LLVM::BitcastOp::create(rewriter, loc, intResultType, existing);
2216 else
2217 existing = LLVM::ZeroOp::create(rewriter, loc, intResultType);
2218
2219 if (sourceVecType.getNumElements() < 2) {
2220 Value c0 = createI32Constant(rewriter, loc, 0);
2221 Value elem0 = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
2222 VectorType v2 = VectorType::get(2, sourceElemType);
2223 source = LLVM::ZeroOp::create(rewriter, loc, v2);
2224 source = LLVM::InsertElementOp::create(rewriter, loc, source, elem0, c0);
2225 }
2226
2227 Value sourceA, sourceB;
2228 if (sourceElemType.isF32()) {
2229 Value c0 = createI32Constant(rewriter, loc, 0);
2230 Value c1 = createI32Constant(rewriter, loc, 1);
2231 sourceA = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
2232 sourceB = LLVM::ExtractElementOp::create(rewriter, loc, source, c1);
2233 }
2234
2235 Value result;
2236 if (sourceElemType.isF32() && isa<Float8E5M2Type>(resultElemType))
2237 result = ROCDL::CvtScaleF32PkBf8F32Op::create(rewriter, loc, intResultType,
2238 existing, sourceA, sourceB,
2239 scale, op.getIndex());
2240 else if (sourceElemType.isF16() && isa<Float8E5M2Type>(resultElemType))
2241 result = ROCDL::CvtScaleF32PkBf8F16Op::create(
2242 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2243 else if (sourceElemType.isBF16() && isa<Float8E5M2Type>(resultElemType))
2244 result = ROCDL::CvtScaleF32PkBf8Bf16Op::create(
2245 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2246 else if (sourceElemType.isF32() && isa<Float8E4M3FNType>(resultElemType))
2247 result = ROCDL::CvtScaleF32PkFp8F32Op::create(rewriter, loc, intResultType,
2248 existing, sourceA, sourceB,
2249 scale, op.getIndex());
2250 else if (sourceElemType.isF16() && isa<Float8E4M3FNType>(resultElemType))
2251 result = ROCDL::CvtScaleF32PkFp8F16Op::create(
2252 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2253 else if (sourceElemType.isBF16() && isa<Float8E4M3FNType>(resultElemType))
2254 result = ROCDL::CvtScaleF32PkFp8Bf16Op::create(
2255 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2256 else if (sourceElemType.isF32() && isa<Float4E2M1FNType>(resultElemType))
2257 result = ROCDL::CvtScaleF32PkFp4F32Op::create(rewriter, loc, intResultType,
2258 existing, sourceA, sourceB,
2259 scale, op.getIndex());
2260 else if (sourceElemType.isF16() && isa<Float4E2M1FNType>(resultElemType))
2261 result = ROCDL::CvtScaleF32PkFp4F16Op::create(
2262 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2263 else if (sourceElemType.isBF16() && isa<Float4E2M1FNType>(resultElemType))
2264 result = ROCDL::CvtScaleF32PkFp4Bf16Op::create(
2265 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2266 else
2267 return failure();
2268
2269 result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
2270 op, getTypeConverter()->convertType(resultType), result);
2271 return success();
2272}
2273
2274LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
2275 PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
2276 ConversionPatternRewriter &rewriter) const {
2277 Location loc = op.getLoc();
2278 if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
2279 return rewriter.notifyMatchFailure(
2280 loc, "Fp8 conversion instructions are not available on target "
2281 "architecture and their emulation is not implemented");
2282 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2283
2284 Type resultType = op.getResult().getType();
2285 Type resultElemType = getElementTypeOrSelf(resultType);
2286
2287 Value sourceA = adaptor.getSourceA();
2288 Value sourceB = adaptor.getSourceB();
2289 if (!sourceB)
2290 sourceB = LLVM::UndefOp::create(rewriter, loc, sourceA.getType());
2291 Value existing = adaptor.getExisting();
2292 if (existing)
2293 existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
2294 else
2295 existing = LLVM::UndefOp::create(rewriter, loc, i32);
2296
2297 Value result;
2298 if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
2299 result = ROCDL::CvtPkBf8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
2300 existing, op.getWordIndex());
2301 else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
2302 result = ROCDL::CvtPkFp8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
2303 existing, op.getWordIndex());
2304
2305 result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
2306 op, getTypeConverter()->convertType(resultType), result);
2307 return success();
2308}
2309
2310LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
2311 PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
2312 ConversionPatternRewriter &rewriter) const {
2313 Location loc = op.getLoc();
2314 if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
2315 return rewriter.notifyMatchFailure(
2316 loc, "Fp8 conversion instructions are not available on target "
2317 "architecture and their emulation is not implemented");
2318 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2319
2320 Type resultType = op.getResult().getType();
2321 Type resultElemType = getElementTypeOrSelf(resultType);
2322
2323 Value source = adaptor.getSource();
2324 Value stoch = adaptor.getStochiasticParam();
2325 Value existing = adaptor.getExisting();
2326 if (existing)
2327 existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
2328 else
2329 existing = LLVM::UndefOp::create(rewriter, loc, i32);
2330
2331 Value result;
2332 if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
2333 result = ROCDL::CvtSrBf8F32Op::create(rewriter, loc, i32, source, stoch,
2334 existing, op.getStoreIndex());
2335 else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
2336 result = ROCDL::CvtSrFp8F32Op::create(rewriter, loc, i32, source, stoch,
2337 existing, op.getStoreIndex());
2338
2339 result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
2340 op, getTypeConverter()->convertType(resultType), result);
2341 return success();
2342}
2343
2344 // AMDGPUDPPLowering converts the amdgpu.dpp operation into the
2345 // corresponding ROCDL DPP update instruction.
2346struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
2347 AMDGPUDPPLowering(const LLVMTypeConverter &converter, Chipset chipset)
2348 : ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
2349 Chipset chipset;
2350
2351 LogicalResult
2352 matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
2353 ConversionPatternRewriter &rewriter) const override {
2354
2355 // Convert the source operand to the corresponding LLVM type
2356 Location loc = DppOp.getLoc();
2357 Value src = adaptor.getSrc();
2358 Value old = adaptor.getOld();
2359 Type srcType = src.getType();
2360 Type oldType = old.getType();
2361 Type llvmType = nullptr;
2362 if (srcType.getIntOrFloatBitWidth() < 32) {
2363 llvmType = rewriter.getI32Type();
2364 } else if (isa<FloatType>(srcType)) {
2365 llvmType = (srcType.getIntOrFloatBitWidth() == 32)
2366 ? rewriter.getF32Type()
2367 : rewriter.getF64Type();
2368 } else if (isa<IntegerType>(srcType)) {
2369 llvmType = (srcType.getIntOrFloatBitWidth() == 32)
2370 ? rewriter.getI32Type()
2371 : rewriter.getI64Type();
2372 }
2373 auto llvmSrcIntType = typeConverter->convertType(
2374 rewriter.getIntegerType(srcType.getIntOrFloatBitWidth()));
2375
2376 // If the source type is narrower than 32 bits, pack it into an i32.
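// For example, an f16 operand is bitcast to i16, inserted into element 0 of
// a vector<2xi16>, and that vector is then bitcast to i32.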
2377 auto convertOperand = [&](Value operand, Type operandType) {
2378 if (operandType.getIntOrFloatBitWidth() <= 16) {
2379 if (llvm::isa<FloatType>(operandType)) {
2380 operand =
2381 LLVM::BitcastOp::create(rewriter, loc, llvmSrcIntType, operand);
2382 }
2383 auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
2384 32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
2385 Value undefVec = LLVM::UndefOp::create(rewriter, loc, llvmVecType);
2386 operand =
2387 LLVM::InsertElementOp::create(rewriter, loc, undefVec, operand,
2388 createI32Constant(rewriter, loc, 0));
2389 operand = LLVM::BitcastOp::create(rewriter, loc, llvmType, operand);
2390 }
2391 return operand;
2392 };
2393
2394 src = convertOperand(src, srcType);
2395 old = convertOperand(old, oldType);
2396
2397 // These values are taken from llvm/lib/Target/AMDGPU/SIDefines.h.
2398 enum DppCtrl : unsigned {
2399 ROW_SHL0 = 0x100,
2400 ROW_SHR0 = 0x110,
2401 ROW_ROR0 = 0x120,
2402 WAVE_SHL1 = 0x130,
2403 WAVE_ROL1 = 0x134,
2404 WAVE_SHR1 = 0x138,
2405 WAVE_ROR1 = 0x13C,
2406 ROW_MIRROR = 0x140,
2407 ROW_HALF_MIRROR = 0x141,
2408 BCAST15 = 0x142,
2409 BCAST31 = 0x143,
2410 };
2411
2412 auto kind = DppOp.getKind();
2413 auto permArgument = DppOp.getPermArgument();
2414 uint32_t DppCtrl = 0;
2415
2416 switch (kind) {
2417
2418 case DPPPerm::quad_perm: {
2419 auto quadPermAttr = cast<ArrayAttr>(*permArgument);
2420 int32_t i = 0;
2421 for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
2422 uint32_t num = elem.getInt();
2423 DppCtrl |= num << (i * 2);
2424 i++;
2425 }
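// For example, quad_perm = [3, 2, 1, 0] yields
// DppCtrl = 3 | (2 << 2) | (1 << 4) | (0 << 6) = 0x1B, which reverses the
// lanes within each group of four.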
2426 break;
2427 }
2428 case DPPPerm::row_shl: {
2429 auto intAttr = cast<IntegerAttr>(*permArgument);
2430 DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
2431 break;
2432 }
2433 case DPPPerm::row_shr: {
2434 auto intAttr = cast<IntegerAttr>(*permArgument);
2435 DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
2436 break;
2437 }
2438 case DPPPerm::row_ror: {
2439 auto intAttr = cast<IntegerAttr>(*permArgument);
2440 DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
2441 break;
2442 }
2443 case DPPPerm::wave_shl:
2444 DppCtrl = DppCtrl::WAVE_SHL1;
2445 break;
2446 case DPPPerm::wave_shr:
2447 DppCtrl = DppCtrl::WAVE_SHR1;
2448 break;
2449 case DPPPerm::wave_rol:
2450 DppCtrl = DppCtrl::WAVE_ROL1;
2451 break;
2452 case DPPPerm::wave_ror:
2453 DppCtrl = DppCtrl::WAVE_ROR1;
2454 break;
2455 case DPPPerm::row_mirror:
2456 DppCtrl = DppCtrl::ROW_MIRROR;
2457 break;
2458 case DPPPerm::row_half_mirror:
2459 DppCtrl = DppCtrl::ROW_HALF_MIRROR;
2460 break;
2461 case DPPPerm::row_bcast_15:
2462 DppCtrl = DppCtrl::BCAST15;
2463 break;
2464 case DPPPerm::row_bcast_31:
2465 DppCtrl = DppCtrl::BCAST31;
2466 break;
2467 }
2468
2469 // Read the row_mask, bank_mask, and bound_ctrl attributes to pass to the
2470 // DPP update operation.
2471 auto rowMask = DppOp->getAttrOfType<IntegerAttr>("row_mask").getInt();
2472 auto bankMask = DppOp->getAttrOfType<IntegerAttr>("bank_mask").getInt();
2473 bool boundCtrl = DppOp->getAttrOfType<BoolAttr>("bound_ctrl").getValue();
2474
2475 // Create a ROCDL DPPUpdateOp (DPP mov) with the appropriate attributes.
2476 auto dppMovOp =
2477 ROCDL::DPPUpdateOp::create(rewriter, loc, llvmType, old, src, DppCtrl,
2478 rowMask, bankMask, boundCtrl);
2479
2480 Value result = dppMovOp.getRes();
2481 if (srcType.getIntOrFloatBitWidth() < 32) {
2482 result = LLVM::TruncOp::create(rewriter, loc, llvmSrcIntType, result);
2483 if (!llvm::isa<IntegerType>(srcType)) {
2484 result = LLVM::BitcastOp::create(rewriter, loc, srcType, result);
2485 }
2486 }
2487
2488 // Replace the amdgpu.dpp operation with the newly created
2489 // ROCDL DPP operation.
2490 rewriter.replaceOp(DppOp, ValueRange(result));
2491 return success();
2492 }
2493};
2494
2495struct AMDGPUSwizzleBitModeLowering
2496 : public ConvertOpToLLVMPattern<SwizzleBitModeOp> {
2497 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
2498
2499 LogicalResult
2500 matchAndRewrite(SwizzleBitModeOp op, OpAdaptor adaptor,
2501 ConversionPatternRewriter &rewriter) const override {
2502 Location loc = op.getLoc();
2503 Type i32 = rewriter.getI32Type();
2504 Value src = adaptor.getSrc();
2505 SmallVector<Value> decomposed =
2506 LLVM::decomposeValue(rewriter, loc, src, i32);
2507 unsigned andMask = op.getAndMask();
2508 unsigned orMask = op.getOrMask();
2509 unsigned xorMask = op.getXorMask();
2510
2511 // bit 15 is 0 for the BitMode swizzle.
2512 // https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/
2513 unsigned mask = andMask | (orMask << 5) | (xorMask << 10);
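// For example, andMask = 0x1F, orMask = 0, and xorMask = 1 give
// mask = 0x1F | (0 << 5) | (1 << 10) = 0x41F, which swaps adjacent lanes.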
2514 Value maskValue = createI32Constant(rewriter, loc, mask);
2515 SmallVector<Value> swizzled;
2516 for (Value v : decomposed) {
2517 Value res =
2518 ROCDL::DsSwizzleOp::create(rewriter, loc, v.getType(), v, maskValue);
2519 swizzled.emplace_back(res);
2520 }
2521
2522 Value result = LLVM::composeValue(rewriter, loc, swizzled, src.getType());
2523 rewriter.replaceOp(op, result);
2524 return success();
2525 }
2526};
2527
2528struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneSwapOp> {
2529 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
2530
2531 AMDGPUPermlaneLowering(const LLVMTypeConverter &converter, Chipset chipset)
2532 : ConvertOpToLLVMPattern<PermlaneSwapOp>(converter), chipset(chipset) {}
2533 Chipset chipset;
2534
2535 LogicalResult
2536 matchAndRewrite(PermlaneSwapOp op, OpAdaptor adaptor,
2537 ConversionPatternRewriter &rewriter) const override {
2538 if (chipset < kGfx950)
2539 return op->emitOpError("permlane_swap is only supported on gfx950+");
2540
2541 Location loc = op.getLoc();
2542 Type i32 = rewriter.getI32Type();
2543 Value src = adaptor.getSrc();
2544 unsigned rowLength = op.getRowLength();
2545 bool fi = op.getFetchInactive();
2546 bool boundctrl = op.getBoundCtrl();
2547
2548 SmallVector<Value> decomposed =
2549 LLVM::decomposeValue(rewriter, loc, src, i32);
2550
2551 SmallVector<Value> permuted;
2552 for (Value v : decomposed) {
2553 Value res;
2554 Type i32pair = LLVM::LLVMStructType::getLiteral(
2555 rewriter.getContext(), {v.getType(), v.getType()});
2556
2557 if (rowLength == 16)
2558 res = ROCDL::Permlane16SwapOp::create(rewriter, loc, i32pair, v, v, fi,
2559 boundctrl);
2560 else if (rowLength == 32)
2561 res = ROCDL::Permlane32SwapOp::create(rewriter, loc, i32pair, v, v, fi,
2562 boundctrl);
2563 else
2564 llvm_unreachable("unsupported row length");
2565
2566 Value vdst0 = LLVM::ExtractValueOp::create(rewriter, loc, res, {0});
2567 Value vdst1 = LLVM::ExtractValueOp::create(rewriter, loc, res, {1});
2568
2569 Value isEqual = LLVM::ICmpOp::create(rewriter, loc,
2570 LLVM::ICmpPredicate::eq, vdst0, v);
2571
2572 // Per `permlane(16|32)` semantics: if the first extracted element equals
2573 // 'v', the result is the second element; otherwise it is the first.
2574 Value vdstNew =
2575 LLVM::SelectOp::create(rewriter, loc, isEqual, vdst1, vdst0);
2576 permuted.emplace_back(vdstNew);
2577 }
2578
2579 Value result = LLVM::composeValue(rewriter, loc, permuted, src.getType());
2580 rewriter.replaceOp(op, result);
2581 return success();
2582 }
2583};
2584
2585static Value setValueAtOffset(ConversionPatternRewriter &rewriter, Location loc,
2586 Value accumulator, Value value, int64_t shift) {
2587 shift = shift % 32;
2588 Value shiftAmount;
2589 if (shift != 0) {
2590 shiftAmount = createI32Constant(rewriter, loc, shift % 32);
2591 value = LLVM::ShlOp::create(rewriter, loc, value, shiftAmount);
2592 }
2593
2594 if (matchPattern(accumulator, mlir::m_Zero()))
2595 return value;
2596
2597 constexpr bool isDisjoint = true;
2598 return LLVM::OrOp::create(rewriter, loc, accumulator, value, isDisjoint);
2599}
2600
2601template <typename BaseOp>
2602struct AMDGPUMakeDmaBaseLowering : public ConvertOpToLLVMPattern<BaseOp> {
2603 using ConvertOpToLLVMPattern<BaseOp>::ConvertOpToLLVMPattern;
2604 using Adaptor = typename ConvertOpToLLVMPattern<BaseOp>::OpAdaptor;
2605
2606 AMDGPUMakeDmaBaseLowering(const LLVMTypeConverter &converter, Chipset chipset)
2607 : ConvertOpToLLVMPattern<BaseOp>(converter), chipset(chipset) {}
2608 Chipset chipset;
2609
2610 LogicalResult
2611 matchAndRewrite(BaseOp op, Adaptor adaptor,
2612 ConversionPatternRewriter &rewriter) const override {
2613 if (chipset < kGfx1250)
2614 return op->emitOpError("make_dma_base is only supported on gfx1250");
2615
2616 Location loc = op.getLoc();
2617
2618 constexpr int32_t constlen = 4;
2619 Value consts[constlen];
2620 for (int64_t i = 0; i < constlen; ++i)
2621 consts[i] = createI32Constant(rewriter, loc, i);
2622
2623 constexpr int32_t sgprslen = constlen;
2624 Value sgprs[sgprslen];
2625 for (int64_t i = 0; i < sgprslen; ++i) {
2626 sgprs[i] = consts[0];
2627 }
2628
2629 sgprs[0] = consts[1];
2630
2631 if constexpr (BaseOp::isGather()) {
2632 sgprs[0] = setValueAtOffset(rewriter, loc, sgprs[0], consts[1], 30);
2633
2634 auto type = cast<TDMGatherBaseType>(op.getResult().getType());
2635 Type indexType = type.getIndexType();
2636 unsigned indexSize = indexType.getIntOrFloatBitWidth();
2637 assert(llvm::is_contained({16u, 32u}, indexSize) &&
2638 "expected index_size to be 16 or 32");
2639 unsigned idx = (indexSize / 16) - 1;
2640
2641 if (idx)
2642 sgprs[0] = setValueAtOffset(rewriter, loc, sgprs[0], consts[1], 31);
2643 }
2644
2645 ValueRange ldsIndices = adaptor.getLdsIndices();
2646 Value lds = adaptor.getLds();
2647 auto ldsMemRefType = cast<MemRefType>(op.getLds().getType());
2648
2649 Value ldsPtr = ConvertToLLVMPattern::getStridedElementPtr(
2650 rewriter, loc, ldsMemRefType, lds, ldsIndices);
2651
2652 ValueRange globalIndices = adaptor.getGlobalIndices();
2653 Value global = adaptor.getGlobal();
2654 auto globalMemRefType = cast<MemRefType>(op.getGlobal().getType());
2655
2656 Value globalPtr = ConvertToLLVMPattern::getStridedElementPtr(
2657 rewriter, loc, globalMemRefType, global, globalIndices);
2658
2659 Type i32 = rewriter.getI32Type();
2660 Type i64 = rewriter.getI64Type();
2661
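// sgprs[1] holds the 32-bit LDS address. The 64-bit global address is split
// into its low 32 bits (sgprs[2]) and its high bits, which are masked to 25
// bits and combined with the constant 2 at bit offset 30 to form sgprs[3].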
2662 sgprs[1] = LLVM::PtrToIntOp::create(rewriter, loc, i32, ldsPtr);
2663 Value castForGlobalAddr =
2664 LLVM::PtrToIntOp::create(rewriter, loc, i64, globalPtr);
2665
2666 sgprs[2] = LLVM::TruncOp::create(rewriter, loc, i32, castForGlobalAddr);
2667
2668 Value shift = LLVM::LShrOp::create(rewriter, loc, castForGlobalAddr,
2669 createI64Constant(rewriter, loc, 32));
2670
2671 Value highHalf = LLVM::TruncOp::create(rewriter, loc, i32, shift);
2672
2673 Value mask = createI32Constant(rewriter, loc, (1ull << 25) - 1);
2674 highHalf = LLVM::AndOp::create(rewriter, loc, highHalf, mask);
2675
2676 sgprs[3] = setValueAtOffset(rewriter, loc, highHalf, consts[2], 30);
2677
2678 Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
2679 assert(v4i32 && "expected type conversion to succeed");
2680 Value result = LLVM::PoisonOp::create(rewriter, loc, v4i32);
2681
2682 for (auto [sgpr, constant] : llvm::zip_equal(sgprs, consts))
2683 result =
2684 LLVM::InsertElementOp::create(rewriter, loc, result, sgpr, constant);
2685
2686 rewriter.replaceOp(op, result);
2687 return success();
2688 }
2689};
2690
2691template <typename DescriptorOp>
2692struct AMDGPULowerDescriptor : public ConvertOpToLLVMPattern<DescriptorOp> {
2693 using ConvertOpToLLVMPattern<DescriptorOp>::ConvertOpToLLVMPattern;
2694 using OpAdaptor = typename ConvertOpToLLVMPattern<DescriptorOp>::OpAdaptor;
2695
2696 AMDGPULowerDescriptor(const LLVMTypeConverter &converter, Chipset chipset)
2697 : ConvertOpToLLVMPattern<DescriptorOp>(converter), chipset(chipset) {}
2698 Chipset chipset;
2699
2700 Value getDGroup0(OpAdaptor adaptor) const { return adaptor.getBase(); }
2701
2702 Value setWorkgroupMask(DescriptorOp op, OpAdaptor adaptor,
2703 ConversionPatternRewriter &rewriter, Location loc,
2704 Value sgpr0) const {
2705 Value mask = op.getWorkgroupMask();
2706 if (!mask)
2707 return sgpr0;
2708
2709 Type i16 = rewriter.getI16Type();
2710 mask = LLVM::BitcastOp::create(rewriter, loc, i16, mask);
2711 Type i32 = rewriter.getI32Type();
2712 Value extendedMask = LLVM::ZExtOp::create(rewriter, loc, i32, mask);
2713 return setValueAtOffset(rewriter, loc, sgpr0, extendedMask, 0);
2714 }
2715
2716 Value setDataSize(DescriptorOp op, OpAdaptor adaptor,
2717 ConversionPatternRewriter &rewriter, Location loc,
2718 Value sgpr0, ArrayRef<Value> consts) const {
2719 unsigned elementTypeWidthInBits = op.getElementTypeWidth();
2720 assert(llvm::is_contained({8u, 16u, 32u, 64u}, elementTypeWidthInBits) &&
2721 "expected type width to be 8, 16, 32, or 64.");
2722 int64_t idx = llvm::Log2_32(elementTypeWidthInBits / 8);
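// e.g. a 32-bit element type gives Log2_32(32 / 8) = 2, which is then placed
// at bit offset 16 of sgpr0.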
2723 Value size = consts[idx];
2724 return setValueAtOffset(rewriter, loc, sgpr0, size, 16);
2725 }
2726
2727 Value setAtomicBarrier(DescriptorOp op, OpAdaptor adaptor,
2728 ConversionPatternRewriter &rewriter, Location loc,
2729 Value sgpr0, ArrayRef<Value> consts) const {
2730 if (!adaptor.getAtomicBarrierAddress())
2731 return sgpr0;
2732
2733 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 18);
2734 }
2735
2736 Value setIterateEnable(DescriptorOp op, OpAdaptor adaptor,
2737 ConversionPatternRewriter &rewriter, Location loc,
2738 Value sgpr0, ArrayRef<Value> consts) const {
2739 if (!adaptor.getGlobalIncrement())
2740 return sgpr0;
2741
2742 // Value is ignored when in gather mode.
2743 // TODO: emit error earlier?
2744 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 19);
2745 }
2746
2747 Value setPadEnable(DescriptorOp op, OpAdaptor adaptor,
2748 ConversionPatternRewriter &rewriter, Location loc,
2749 Value sgpr0, ArrayRef<Value> consts) const {
2750 if (!op.getPadAmount())
2751 return sgpr0;
2752
2753 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 20);
2754 }
2755
2756 Value setEarlyTimeout(DescriptorOp op, OpAdaptor adaptor,
2757 ConversionPatternRewriter &rewriter, Location loc,
2758 Value sgpr0, ArrayRef<Value> consts) const {
2759 if (!op.getWorkgroupMask())
2760 return sgpr0;
2761
2762 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 21);
2763 }
2764
2765 Value setPadInterval(DescriptorOp op, OpAdaptor adaptor,
2766 ConversionPatternRewriter &rewriter, Location loc,
2767 Value sgpr0, ArrayRef<Value> consts) const {
2768 if (!op.getPadAmount())
2769 return sgpr0;
2770
2771 // pre-condition: padInterval must be a power of two between 2 and 256.
2772 // TODO: validate that the value satisfies the pre-condition.
2773 // If the pre-condition is violated, the higher bits
2774 // may be corrupted. In a follow-up PR, implement
2775 // RuntimeVerifiableOpInterface to instrument conditions that need to be
2776 // checked at runtime.
2777 IntegerType i32 = rewriter.getI32Type();
2778 Value padInterval = adaptor.getPadInterval();
2779 padInterval = LLVM::CountTrailingZerosOp::create(rewriter, loc, i32,
2780 padInterval, false);
2781 padInterval = LLVM::SubOp::create(rewriter, loc, padInterval, consts[1]);
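// e.g. a padInterval of 16 is encoded as cttz(16) - 1 == 3.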
2782 // post-condition: padInterval is a value between 0 and 7.
2783 return setValueAtOffset(rewriter, loc, sgpr0, padInterval, 22);
2784 }
2785
2786 Value setPadAmount(DescriptorOp op, OpAdaptor adaptor,
2787 ConversionPatternRewriter &rewriter, Location loc,
2788 Value sgpr0, ArrayRef<Value> consts) const {
2789 if (!op.getPadAmount())
2790 return sgpr0;
2791
2792 // pre-condition: padAmount is a value between 1 and 128.
2793 // TODO: validate that the value satisfies this pre-condition.
2794 // If the pre-condition fails, there is a possibility of
2795 // affecting the higher bits. In a following PR, implement
2796 // RuntimeVerifiableOpInterface to instrument the conditions that need
2797 // to be checked at runtime.
2798 Value padAmount = adaptor.getPadAmount();
2799 padAmount = LLVM::SubOp::create(rewriter, loc, padAmount, consts[1]);
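// e.g. a padAmount of 128 is encoded as 127.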
2800 // post-condition: padAmount is a value between 0 and 127.
2801 return setValueAtOffset(rewriter, loc, sgpr0, padAmount, 25);
2802 }
2803
2804 Value setAtomicBarrierAddress(DescriptorOp op, OpAdaptor adaptor,
2805 ConversionPatternRewriter &rewriter,
2806 Location loc, Value sgpr1,
2807 ArrayRef<Value> consts) const {
2808 if (!adaptor.getAtomicBarrierAddress())
2809 return sgpr1;
2810
2811 Value atomicBarrierAddress = adaptor.getAtomicBarrierAddress();
2812 auto barrierAddressTy =
2813 cast<MemRefType>(op.getAtomicBarrierAddress().getType());
2814 ValueRange atomicBarrierIndices = adaptor.getAtomicBarrierIndices();
2815 atomicBarrierAddress = ConvertToLLVMPattern::getStridedElementPtr(
2816 rewriter, loc, barrierAddressTy, atomicBarrierAddress,
2817 atomicBarrierIndices);
2818 IntegerType i32 = rewriter.getI32Type();
2819 // pre-condition: atomicBarrierAddress is aligned to 8 bytes which implies
2820 // that the 3 LSBs are zero.
2821 // TODO: validate that the value satisfies this pre-condition.
2822 // In a following PR, implement RuntimeVerifiableOpInterface
2823 // to instrument the conditions that need to be checked at runtime.
2824 atomicBarrierAddress =
2825 LLVM::PtrToIntOp::create(rewriter, loc, i32, atomicBarrierAddress);
2826 atomicBarrierAddress =
2827 LLVM::LShrOp::create(rewriter, loc, atomicBarrierAddress, consts[3]);
2828 Value mask = createI32Constant(rewriter, loc, 0xFFFF);
2829 atomicBarrierAddress =
2830 LLVM::AndOp::create(rewriter, loc, atomicBarrierAddress, mask);
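// The barrier address is thus stored as a 16-bit offset in units of 8 bytes.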
2831 return setValueAtOffset(rewriter, loc, sgpr1, atomicBarrierAddress, 32);
2832 }
2833
2834 std::pair<Value, Value> setTensorDimX(DescriptorOp op, OpAdaptor adaptor,
2835 ConversionPatternRewriter &rewriter,
2836 Location loc, Value sgpr1, Value sgpr2,
2837 ArrayRef<Value> consts, uint64_t dimX,
2838 uint32_t offset) const {
2839 ArrayRef<int64_t> globalStaticSizes = adaptor.getGlobalStaticSizes();
2840 ValueRange globalDynamicSizes = adaptor.getGlobalDynamicSizes();
2841 SmallVector<OpFoldResult> mixedGlobalSizes =
2842 getMixedValues(globalStaticSizes, globalDynamicSizes, rewriter);
2843 if (mixedGlobalSizes.size() <= dimX)
2844 return {sgpr1, sgpr2};
2845
2846 OpFoldResult tensorDimXOpFoldResult = *(mixedGlobalSizes.rbegin() + dimX);
2847 // pre-condition: tensorDimX is less than 2^32-1.
2848 // TODO: validate that the value satisfies this pre-condition.
2849 // In a following PR, implement RuntimeVerifiableOpInterface to instrument
2850 // the conditions that need to be checked at runtime. This could also be
2851 // fixed by declaring mixedGlobalSizes to be a DynamicI32List.
2852 Value tensorDimX;
2853 if (auto attr = dyn_cast<Attribute>(tensorDimXOpFoldResult)) {
2854 tensorDimX =
2855 createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
2856 } else {
2857 IntegerType i32 = rewriter.getI32Type();
2858 tensorDimX = cast<Value>(tensorDimXOpFoldResult);
2859 tensorDimX = LLVM::TruncOp::create(rewriter, loc, i32, tensorDimX);
2860 }
2861
2862 sgpr1 = setValueAtOffset(rewriter, loc, sgpr1, tensorDimX, offset);
2863
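// A 32-bit dimension may straddle two SGPRs: the value is written at
// `offset`, and its upper 16 bits are written again at `offset + 16`, which
// lands in the second SGPR.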
2864 Value c16 = createI32Constant(rewriter, loc, 16);
2865 Value tensorDimXHigh = LLVM::LShrOp::create(rewriter, loc, tensorDimX, c16);
2866 sgpr2 = setValueAtOffset(rewriter, loc, sgpr2, tensorDimXHigh, offset + 16);
2867 return {sgpr1, sgpr2};
2868 }
2869
2870 std::pair<Value, Value> setTensorDim0(DescriptorOp op, OpAdaptor adaptor,
2871 ConversionPatternRewriter &rewriter,
2872 Location loc, Value sgpr1, Value sgpr2,
2873 ArrayRef<Value> consts) const {
2874 return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, sgpr2, consts, 0,
2875 48);
2876 }
2877
2878 std::pair<Value, Value> setTensorDim1(DescriptorOp op, OpAdaptor adaptor,
2879 ConversionPatternRewriter &rewriter,
2880 Location loc, Value sgpr2, Value sgpr3,
2881 ArrayRef<Value> consts) const {
2882 return setTensorDimX(op, adaptor, rewriter, loc, sgpr2, sgpr3, consts, 1,
2883 80);
2884 }
2885
2886 Value setTileDimX(DescriptorOp op, OpAdaptor adaptor,
2887 ConversionPatternRewriter &rewriter, Location loc,
2888 Value sgpr, ArrayRef<Value> consts, size_t dimX,
2889 int64_t offset) const {
2890 ArrayRef<int64_t> sharedStaticSizes = adaptor.getSharedStaticSizes();
2891 ValueRange sharedDynamicSizes = adaptor.getSharedDynamicSizes();
2892 SmallVector<OpFoldResult> mixedSharedSizes =
2893 getMixedValues(sharedStaticSizes, sharedDynamicSizes, rewriter);
2894 if (mixedSharedSizes.size() <= dimX)
2895 return sgpr;
2896
2897 OpFoldResult tileDimXOpFoldResult = *(mixedSharedSizes.rbegin() + dimX);
2898 // pre-condition: tileDimX is less than 2^16-1.
2899 // TODO: validate that the value satisfies this pre-condition.
2900 // If the pre-condition fails, there is a possibility of
2901 // affecting the higher bits. In a following PR, implement
2902 // RuntimeVerifiableOpInterface to instrument the conditions that need
2903 // to be checked at runtime. This could also be fixed by declaring
2904 // mixedSharedSizes to be a DynamicI16List.
2905 Value tileDimX;
2906 if (auto attr = dyn_cast<Attribute>(tileDimXOpFoldResult)) {
2907 tileDimX =
2908 createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
2909 } else {
2910 IntegerType i32 = rewriter.getI32Type();
2911 tileDimX = cast<Value>(tileDimXOpFoldResult);
2912 tileDimX = LLVM::TruncOp::create(rewriter, loc, i32, tileDimX);
2913 }
2914
2915 return setValueAtOffset(rewriter, loc, sgpr, tileDimX, offset);
2916 }
2917
2918 Value setTileDim0(DescriptorOp op, OpAdaptor adaptor,
2919 ConversionPatternRewriter &rewriter, Location loc,
2920 Value sgpr3, ArrayRef<Value> consts) const {
2921 return setTileDimX(op, adaptor, rewriter, loc, sgpr3, consts, 0, 112);
2922 }
2923
2924 Value setTileDim1(DescriptorOp op, OpAdaptor adaptor,
2925 ConversionPatternRewriter &rewriter, Location loc,
2926 Value sgpr4, ArrayRef<Value> consts) const {
2927 return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 1, 128);
2928 }
2929
2930 Value setValidIndices(DescriptorOp op, OpAdaptor adaptor,
2931 ConversionPatternRewriter &rewriter, Location loc,
2932 Value sgpr4, ArrayRef<Value> consts) const {
2933 auto type = cast<VectorType>(op.getIndices().getType());
2934 ArrayRef<int64_t> shape = type.getShape();
2935 assert(shape.size() == 1 && "expected shape to be of rank 1.");
2936 unsigned length = shape.back();
2937 assert(0 < length && length <= 16 && "expected length to be between 1 and 16.");
2938 Value value = createI32Constant(rewriter, loc, length);
2939 return setValueAtOffset(rewriter, loc, sgpr4, value, 128);
2940 }
2941
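// In gather mode, the field that otherwise holds tile dim 1 holds the number
// of gather indices instead.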
2942 Value setTileDim1OrValidIndices(DescriptorOp op, OpAdaptor adaptor,
2943 ConversionPatternRewriter &rewriter,
2944 Location loc, Value sgpr4,
2945 ArrayRef<Value> consts) const {
2946 if constexpr (DescriptorOp::isGather())
2947 return setValidIndices(op, adaptor, rewriter, loc, sgpr4, consts);
2948 return setTileDim1(op, adaptor, rewriter, loc, sgpr4, consts);
2949 }
2950
2951 Value setTileDim2(DescriptorOp op, OpAdaptor adaptor,
2952 ConversionPatternRewriter &rewriter, Location loc,
2953 Value sgpr4, ArrayRef<Value> consts) const {
2954 // Value is ignored when in gather mode.
2955 if constexpr (DescriptorOp::isGather())
2956 return sgpr4;
2957 return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 2, 144);
2958 }
2959
2960 std::pair<Value, Value>
2961 setTensorDimXStride(DescriptorOp op, OpAdaptor adaptor,
2962 ConversionPatternRewriter &rewriter, Location loc,
2963 Value sgprY, Value sgprZ, ArrayRef<Value> consts,
2964 size_t dimX, int64_t offset) const {
2965 ArrayRef<int64_t> globalStaticStrides = adaptor.getGlobalStaticStrides();
2966 ValueRange globalDynamicStrides = adaptor.getGlobalDynamicStrides();
2967 SmallVector<OpFoldResult> mixedGlobalStrides =
2968 getMixedValues(globalStaticStrides, globalDynamicStrides, rewriter);
2969
2970 if (mixedGlobalStrides.size() <= dimX)
2971 return {sgprY, sgprZ};
2972
2973 OpFoldResult tensorDimXStrideOpFoldResult =
2974 *(mixedGlobalStrides.rbegin() + dimX);
2975 // pre-condition: tensorDimXStride is less than 2^48-1.
2976 // TODO: validate that the value satisfies this pre-condition.
2977 // In a following PR, implement RuntimeVerifiableOpInterface to instrument
2978 // the conditions that need to be checked at runtime.
2979 Value tensorDimXStride;
2980 if (auto attr = dyn_cast<Attribute>(tensorDimXStrideOpFoldResult))
2981 tensorDimXStride =
2982 createI64Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
2983 else
2984 tensorDimXStride = cast<Value>(tensorDimXStrideOpFoldResult);
2985
2986 constexpr int64_t first48bits = (1ll << 48) - 1;
2987 Value mask = createI64Constant(rewriter, loc, first48bits);
2988 tensorDimXStride =
2989 LLVM::AndOp::create(rewriter, loc, mask, tensorDimXStride);
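// The 48-bit stride is split across two SGPRs: its low 32 bits are written
// at `offset`, and the remaining high bits at `offset + shift` in the next
// SGPR.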
2990 IntegerType i32 = rewriter.getI32Type();
2991 Value tensorDimXStrideLow =
2992 LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStride);
2993 sgprY = setValueAtOffset(rewriter, loc, sgprY, tensorDimXStrideLow, offset);
2994
2995 int64_t shift = (offset % 32) == 0 ? 32 : offset % 32;
2996 Value shiftVal = createI64Constant(rewriter, loc, shift);
2997 Value tensorDimXStrideHigh =
2998 LLVM::LShrOp::create(rewriter, loc, tensorDimXStride, shiftVal);
2999 tensorDimXStrideHigh =
3000 LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStrideHigh);
3001 sgprZ = setValueAtOffset(rewriter, loc, sgprZ, tensorDimXStrideHigh,
3002 offset + shift);
3003 return {sgprY, sgprZ};
3004 }
3005
3006 std::pair<Value, Value>
3007 setTensorDim0Stride(DescriptorOp op, OpAdaptor adaptor,
3008 ConversionPatternRewriter &rewriter, Location loc,
3009 Value sgpr5, Value sgpr6, ArrayRef<Value> consts) const {
3010 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts,
3011 0, 160);
3012 }
3013
3014 std::pair<Value, Value>
3015 setTensorDim1Stride(DescriptorOp op, OpAdaptor adaptor,
3016 ConversionPatternRewriter &rewriter, Location loc,
3017 Value sgpr5, Value sgpr6, ArrayRef<Value> consts) const {
3018 // Value is ignored when in gather mode.
3019 if constexpr (DescriptorOp::isGather())
3020 return {sgpr5, sgpr6};
3021 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts,
3022 1, 208);
3023 }
3024
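// Assembles dgroup1, an <8 x i32> descriptor group, by packing the workgroup
// mask, data size, flag bits, pad configuration, atomic barrier address,
// tensor/tile dimensions, and strides at fixed bit offsets.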
3025 Value getDGroup1(DescriptorOp op, OpAdaptor adaptor,
3026 ConversionPatternRewriter &rewriter, Location loc,
3027 ArrayRef<Value> consts) const {
3028 Value sgprs[8];
3029 for (int64_t i = 0; i < 8; ++i) {
3030 sgprs[i] = consts[0];
3031 }
3032
3033 sgprs[0] = setWorkgroupMask(op, adaptor, rewriter, loc, sgprs[0]);
3034 sgprs[0] = setDataSize(op, adaptor, rewriter, loc, sgprs[0], consts);
3035 sgprs[0] = setAtomicBarrier(op, adaptor, rewriter, loc, sgprs[0], consts);
3036 sgprs[0] = setIterateEnable(op, adaptor, rewriter, loc, sgprs[0], consts);
3037 sgprs[0] = setPadEnable(op, adaptor, rewriter, loc, sgprs[0], consts);
3038 sgprs[0] = setEarlyTimeout(op, adaptor, rewriter, loc, sgprs[0], consts);
3039 sgprs[0] = setPadInterval(op, adaptor, rewriter, loc, sgprs[0], consts);
3040 sgprs[0] = setPadAmount(op, adaptor, rewriter, loc, sgprs[0], consts);
3041
3042 sgprs[1] =
3043 setAtomicBarrierAddress(op, adaptor, rewriter, loc, sgprs[1], consts);
3044 std::tie(sgprs[1], sgprs[2]) =
3045 setTensorDim0(op, adaptor, rewriter, loc, sgprs[1], sgprs[2], consts);
3046 std::tie(sgprs[2], sgprs[3]) =
3047 setTensorDim1(op, adaptor, rewriter, loc, sgprs[2], sgprs[3], consts);
3048
3049 sgprs[3] = setTileDim0(op, adaptor, rewriter, loc, sgprs[3], consts);
3050 sgprs[4] =
3051 setTileDim1OrValidIndices(op, adaptor, rewriter, loc, sgprs[4], consts);
3052 sgprs[4] = setTileDim2(op, adaptor, rewriter, loc, sgprs[4], consts);
3053 std::tie(sgprs[5], sgprs[6]) = setTensorDim0Stride(
3054 op, adaptor, rewriter, loc, sgprs[5], sgprs[6], consts);
3055 std::tie(sgprs[6], sgprs[7]) = setTensorDim1Stride(
3056 op, adaptor, rewriter, loc, sgprs[6], sgprs[7], consts);
3057
3058 IntegerType i32 = rewriter.getI32Type();
3059 Type v8i32 = this->typeConverter->convertType(VectorType::get(8, i32));
3060 assert(v8i32 && "expected type conversion to succeed");
3061 Value dgroup1 = LLVM::PoisonOp::create(rewriter, loc, v8i32);
3062
3063 for (auto [sgpr, constant] : llvm::zip_equal(sgprs, consts)) {
3064 dgroup1 =
3065 LLVM::InsertElementOp::create(rewriter, loc, dgroup1, sgpr, constant);
3066 }
3067
3068 return dgroup1;
3069 }
3070
3071 Value setTensorDimX(DescriptorOp op, OpAdaptor adaptor,
3072 ConversionPatternRewriter &rewriter, Location loc,
3073 Value sgpr0, ArrayRef<Value> consts, int64_t dimX,
3074 int64_t offset) const {
3075 ArrayRef<int64_t> globalStaticSizes = adaptor.getGlobalStaticSizes();
3076 ValueRange globalDynamicSizes = adaptor.getGlobalDynamicSizes();
3077 SmallVector<OpFoldResult> mixedGlobalSizes =
3078 getMixedValues(globalStaticSizes, globalDynamicSizes, rewriter);
3079 if (mixedGlobalSizes.size() <= static_cast<unsigned long>(dimX))
3080 return sgpr0;
3081
3082 OpFoldResult tensorDimXOpFoldResult = *(mixedGlobalSizes.rbegin() + dimX);
3083 Value tensorDimX;
3084 if (auto attr = dyn_cast<Attribute>(tensorDimXOpFoldResult)) {
3085 tensorDimX =
3086 createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
3087 } else {
3088 IntegerType i32 = rewriter.getI32Type();
3089 tensorDimX = cast<Value>(tensorDimXOpFoldResult);
3090 tensorDimX = LLVM::TruncOp::create(rewriter, loc, i32, tensorDimX);
3091 }
3092
3093 return setValueAtOffset(rewriter, loc, sgpr0, tensorDimX, offset);
3094 }
3095
3096 Value setTensorDim2(DescriptorOp op, OpAdaptor adaptor,
3097 ConversionPatternRewriter &rewriter, Location loc,
3098 Value sgpr0, ArrayRef<Value> consts) const {
3099 return setTensorDimX(op, adaptor, rewriter, loc, sgpr0, consts, 2, 0);
3100 }
3101
3102 Value truncateAndSetValueAtOffset(ConversionPatternRewriter &rewriter,
3103 Location loc, Value accumulator,
3104 Value value, int64_t shift) const {
3105
3106 IntegerType i32 = rewriter.getI32Type();
3107 value = LLVM::TruncOp::create(rewriter, loc, i32, value);
3108 return setValueAtOffset(rewriter, loc, accumulator, value, shift);
3109 }
3110
3111 Value setLDSAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
3112 ConversionPatternRewriter &rewriter, Location loc,
3113 Value sgpr1, ArrayRef<Value> consts,
3114 int64_t offset) const {
3115 Value ldsAddrIncrement = adaptor.getLdsIncrement();
3116 return setValueAtOffset(rewriter, loc, sgpr1, ldsAddrIncrement, offset);
3117 }
3118
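// The global address increment is a 48-bit value: its low 32 bits are written
// at `offset`, the remaining bits at `offset + 32`, and the containing SGPR is
// then masked so that only 16 bits of the upper word remain.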
3119 std::pair<Value, Value>
3120 setGlobalAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
3121 ConversionPatternRewriter &rewriter, Location loc,
3122 Value sgpr2, Value sgpr3, ArrayRef<Value> consts,
3123 int64_t offset) const {
3124 Value globalAddrIncrement = adaptor.getGlobalIncrement();
3125 sgpr2 = truncateAndSetValueAtOffset(rewriter, loc, sgpr2,
3126 globalAddrIncrement, offset);
3127 Value shift = createI64Constant(rewriter, loc, 32);
3128 globalAddrIncrement =
3129 LLVM::LShrOp::create(rewriter, loc, globalAddrIncrement, shift);
3130 constexpr int64_t first16BitsHigh = (1ll << 16) - 1;
3131 sgpr3 = truncateAndSetValueAtOffset(rewriter, loc, sgpr3,
3132 globalAddrIncrement, offset + 32);
3133 Value mask = createI32Constant(rewriter, loc, first16BitsHigh);
3134 sgpr3 = LLVM::AndOp::create(rewriter, loc, sgpr3, mask);
3135 return {sgpr2, sgpr3};
3136 }
3137
3138 Value setTensorDim3OrLDSAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
3139 ConversionPatternRewriter &rewriter,
3140 Location loc, Value sgpr1,
3141 ArrayRef<Value> consts) const {
3142 Value ldsIncrement = op.getLdsIncrement();
3143 constexpr int64_t dim = 3;
3144 constexpr int64_t offset = 32;
3145 if (!ldsIncrement)
3146 return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, consts, dim,
3147 offset);
3148 return setLDSAddrIncrement(op, adaptor, rewriter, loc, sgpr1, consts,
3149 offset);
3150 }
3151
3152 std::pair<Value, Value> setTensorDim2StrideOrGlobalAddrIncrement(
3153 DescriptorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter,
3154 Location loc, Value sgpr2, Value sgpr3, ArrayRef<Value> consts) const {
3155 Value globalIncrement = op.getGlobalIncrement();
3156 constexpr int32_t dim = 2;
3157 constexpr int32_t offset = 64;
3158 if (!globalIncrement)
3159 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr2, sgpr3,
3160 consts, dim, offset);
3161 return setGlobalAddrIncrement(op, adaptor, rewriter, loc, sgpr2, sgpr3,
3162 consts, offset);
3163 }
3164
3165 Value setIterateCount(DescriptorOp op, OpAdaptor adaptor,
3166 ConversionPatternRewriter &rewriter, Location loc,
3167 Value sgpr3, ArrayRef<Value> consts,
3168 int32_t offset) const {
3169 Value iterationCount = adaptor.getIterationCount();
3170 IntegerType i32 = rewriter.getI32Type();
3171 // pre-condition: iterationCount is in the inclusive interval [1, 256].
3172 // TODO: validate that the value satisfies this pre-condition.
3173 // If the pre-condition fails, there is a possibility of
3174 // affecting the higher bits. In a following PR, implement
3175 // RuntimeVerifiableOpInterface to instrument the conditions that need
3176 // to be checked at runtime.
3177 iterationCount = LLVM::TruncOp::create(rewriter, loc, i32, iterationCount);
3178 iterationCount =
3179 LLVM::SubOp::create(rewriter, loc, iterationCount, consts[1]);
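// The field stores iterationCount - 1, so the encoded range is [0, 255].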
3180 return setValueAtOffset(rewriter, loc, sgpr3, iterationCount, offset);
3181 }
3182
3183 Value setTileDim3OrIterateCount(DescriptorOp op, OpAdaptor adaptor,
3184 ConversionPatternRewriter &rewriter,
3185 Location loc, Value sgpr3,
3186 ArrayRef<Value> consts) const {
3187 Value iterateCount = op.getIterationCount();
3188 constexpr int32_t dim = 2;
3189 constexpr int32_t offset = 112;
3190 if (!iterateCount)
3191 return setTileDimX(op, adaptor, rewriter, loc, sgpr3, consts, dim,
3192 offset);
3193
3194 return setIterateCount(op, adaptor, rewriter, loc, sgpr3, consts, offset);
3195 }
3196
3197 Value getDGroup2(DescriptorOp op, OpAdaptor adaptor,
3198 ConversionPatternRewriter &rewriter, Location loc,
3199 ArrayRef<Value> consts) const {
3200 if constexpr (DescriptorOp::isGather())
3201 return getDGroup2Gather(op, adaptor, rewriter, loc, consts);
3202 return getDGroup2NonGather(op, adaptor, rewriter, loc, consts);
3203 }
3204
3205 Value getDGroup2NonGather(DescriptorOp op, OpAdaptor adaptor,
3206 ConversionPatternRewriter &rewriter, Location loc,
3207 ArrayRef<Value> consts) const {
3208 IntegerType i32 = rewriter.getI32Type();
3209 Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
3210 assert(v4i32 && "expected type conversion to succeed.");
3211
3212 bool onlyNeedsTwoDescriptors = !op.getLdsIncrement() && op.getRank() <= 2;
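// With no LDS increment and rank <= 2, only the first two descriptor groups
// are needed, so dgroup2 is emitted as all zeros.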
3213 if (onlyNeedsTwoDescriptors)
3214 return LLVM::ZeroOp::create(rewriter, loc, v4i32);
3215
3216 constexpr int64_t sgprlen = 4;
3217 Value sgprs[sgprlen];
3218 for (int i = 0; i < sgprlen; ++i)
3219 sgprs[i] = consts[0];
3220
3221 sgprs[0] = setTensorDim2(op, adaptor, rewriter, loc, sgprs[0], consts);
3222 sgprs[1] = setTensorDim3OrLDSAddrIncrement(op, adaptor, rewriter, loc,
3223 sgprs[1], consts);
3224 std::tie(sgprs[2], sgprs[3]) = setTensorDim2StrideOrGlobalAddrIncrement(
3225 op, adaptor, rewriter, loc, sgprs[2], sgprs[3], consts);
3226 sgprs[3] =
3227 setTileDim3OrIterateCount(op, adaptor, rewriter, loc, sgprs[3], consts);
3228
3229 Value dgroup2 = LLVM::PoisonOp::create(rewriter, loc, v4i32);
3230 for (auto [sgpr, constant] : llvm::zip(sgprs, consts))
3231 dgroup2 =
3232 LLVM::InsertElementOp::create(rewriter, loc, dgroup2, sgpr, constant);
3233
3234 return dgroup2;
3235 }
3236
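// Packs gather indices into a <4 x i32> group. i32 indices are copied
// directly (up to 4 per group); narrower indices are zero-extended to i32 and
// packed two per 32-bit word (up to 8 per group). `firstHalf` selects which
// half of the index vector this group covers.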
3237 Value getGatherIndices(DescriptorOp op, OpAdaptor adaptor,
3238 ConversionPatternRewriter &rewriter, Location loc,
3239 ArrayRef<Value> consts, bool firstHalf) const {
3240 IntegerType i32 = rewriter.getI32Type();
3241 Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
3242 assert(v4i32 && "expected type conversion to succeed.");
3243
3244 Value indices = adaptor.getIndices();
3245 auto vectorType = cast<VectorType>(indices.getType());
3246 unsigned length = vectorType.getShape().back();
3247 Type elementType = vectorType.getElementType();
3248 unsigned maxLength = elementType == i32 ? 4 : 8;
3249 int32_t offset = firstHalf ? 0 : maxLength;
3250 unsigned discountedLength =
3251 std::max(static_cast<int32_t>(length - offset), 0);
3252
3253 unsigned targetSize = std::min(maxLength, discountedLength);
3254
3255 SmallVector<Value> indicesVector;
3256 for (unsigned i = offset; i < targetSize + offset; ++i) {
3257 Value idx;
3258 if (i < consts.size())
3259 idx = consts[i];
3260 else
3261 idx = createI32Constant(rewriter, loc, i);
3262 Value elem = LLVM::ExtractElementOp::create(rewriter, loc, indices, idx);
3263 indicesVector.push_back(elem);
3264 }
3265
3266 SmallVector<Value> indicesI32Vector;
3267 if (elementType == i32) {
3268 indicesI32Vector = indicesVector;
3269 } else {
3270 for (unsigned i = 0; i < targetSize; ++i) {
3271 Value index = indicesVector[i];
3272 indicesI32Vector.push_back(
3273 LLVM::ZExtOp::create(rewriter, loc, i32, index));
3274 }
3275 if ((targetSize % 2) != 0)
3276 // Add padding when not divisible by two.
3277 indicesI32Vector.push_back(consts[0]);
3278 }
3279
3280 SmallVector<Value> indicesToInsert;
3281 if (elementType == i32) {
3282 indicesToInsert = indicesI32Vector;
3283 } else {
3284 unsigned size = indicesI32Vector.size() / 2;
3285 for (unsigned i = 0; i < size; ++i) {
3286 Value first = indicesI32Vector[2 * i];
3287 Value second = indicesI32Vector[2 * i + 1];
3288 Value joined = setValueAtOffset(rewriter, loc, first, second, 16);
3289 indicesToInsert.push_back(joined);
3290 }
3291 }
3292
3293 Value dgroup = LLVM::PoisonOp::create(rewriter, loc, v4i32);
3294 for (auto [sgpr, constant] : llvm::zip_first(indicesToInsert, consts))
3295 dgroup =
3296 LLVM::InsertElementOp::create(rewriter, loc, dgroup, sgpr, constant);
3297
3298 return dgroup;
3299 }
3300
3301 Value getDGroup2Gather(DescriptorOp op, OpAdaptor adaptor,
3302 ConversionPatternRewriter &rewriter, Location loc,
3303 ArrayRef<Value> consts) const {
3304 return getGatherIndices(op, adaptor, rewriter, loc, consts, true);
3305 }
3306
3307 std::pair<Value, Value>
3308 setTensorDim3Stride(DescriptorOp op, OpAdaptor adaptor,
3309 ConversionPatternRewriter &rewriter, Location loc,
3310 Value sgpr0, Value sgpr1, ArrayRef<Value> consts) const {
3311 constexpr int32_t dim = 3;
3312 constexpr int32_t offset = 0;
3313 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr0, sgpr1, consts,
3314 dim, offset);
3315 }
3316
3317 std::pair<Value, Value> setTensorDim4(DescriptorOp op, OpAdaptor adaptor,
3318 ConversionPatternRewriter &rewriter,
3319 Location loc, Value sgpr1, Value sgpr2,
3320 ArrayRef<Value> consts) const {
3321 constexpr int32_t dim = 4;
3322 constexpr int32_t offset = 48;
3323 return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, sgpr2, consts, dim,
3324 offset);
3325 }
3326
3327 Value setTileDim4(DescriptorOp op, OpAdaptor adaptor,
3328 ConversionPatternRewriter &rewriter, Location loc,
3329 Value sgpr2, ArrayRef<Value> consts) const {
3330 constexpr int32_t dim = 4;
3331 constexpr int32_t offset = 80;
3332 return setTileDimX(op, adaptor, rewriter, loc, sgpr2, consts, dim, offset);
3333 }
3334
3335 Value getDGroup3(DescriptorOp op, OpAdaptor adaptor,
3336 ConversionPatternRewriter &rewriter, Location loc,
3337 ArrayRef<Value> consts) const {
3338 if constexpr (DescriptorOp::isGather())
3339 return getDGroup3Gather(op, adaptor, rewriter, loc, consts);
3340 return getDGroup3NonGather(op, adaptor, rewriter, loc, consts);
3341 }
3342
3343 Value getDGroup3NonGather(DescriptorOp op, OpAdaptor adaptor,
3344 ConversionPatternRewriter &rewriter, Location loc,
3345 ArrayRef<Value> consts) const {
3346 IntegerType i32 = rewriter.getI32Type();
3347 Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
3348 assert(v4i32 && "expected type conversion to succeed.");
3349 bool onlyNeedsTwoDescriptors = !op.getLdsIncrement() && op.getRank() <= 2;
3350 if (onlyNeedsTwoDescriptors)
3351 return LLVM::ZeroOp::create(rewriter, loc, v4i32);
3352
3353 constexpr int32_t sgprlen = 4;
3354 Value sgprs[sgprlen];
3355 for (int i = 0; i < sgprlen; ++i)
3356 sgprs[i] = consts[0];
3357
3358 std::tie(sgprs[0], sgprs[1]) = setTensorDim3Stride(
3359 op, adaptor, rewriter, loc, sgprs[0], sgprs[1], consts);
3360 std::tie(sgprs[1], sgprs[2]) =
3361 setTensorDim4(op, adaptor, rewriter, loc, sgprs[1], sgprs[2], consts);
3362 sgprs[2] = setTileDim4(op, adaptor, rewriter, loc, sgprs[2], consts);
3363
3364 Value dgroup3 = LLVM::PoisonOp::create(rewriter, loc, v4i32);
3365 for (auto [sgpr, constant] : llvm::zip(sgprs, consts))
3366 dgroup3 =
3367 LLVM::InsertElementOp::create(rewriter, loc, dgroup3, sgpr, constant);
3368
3369 return dgroup3;
3370 }
3371
3372 Value getDGroup3Gather(DescriptorOp op, OpAdaptor adaptor,
3373 ConversionPatternRewriter &rewriter, Location loc,
3374 ArrayRef<Value> consts) const {
3375 return getGatherIndices(op, adaptor, rewriter, loc, consts, false);
3376 }
3377
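// Lowers the descriptor op into four descriptor groups: dgroup0 is the
// already-converted base, and dgroup1-dgroup3 are assembled above; the op is
// replaced with all four values.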
3378 LogicalResult
3379 matchAndRewrite(DescriptorOp op, OpAdaptor adaptor,
3380 ConversionPatternRewriter &rewriter) const override {
3381 if (chipset < kGfx1250)
3382 return op->emitOpError(
3383 "make_dma_descriptor is only supported on gfx1250");
3384
3385 Location loc = op.getLoc();
3386
3387 SmallVector<Value> consts;
3388 for (int64_t i = 0; i < 8; ++i)
3389 consts.push_back(createI32Constant(rewriter, loc, i));
3390
3391 Value dgroup0 = this->getDGroup0(adaptor);
3392 Value dgroup1 = this->getDGroup1(op, adaptor, rewriter, loc, consts);
3393 Value dgroup2 = this->getDGroup2(op, adaptor, rewriter, loc, consts);
3394 Value dgroup3 = this->getDGroup3(op, adaptor, rewriter, loc, consts);
3395 SmallVector<Value> results = {dgroup0, dgroup1, dgroup2, dgroup3};
3396 rewriter.replaceOpWithMultiple(op, {results});
3397 return success();
3398 }
3399};
3400
3401template <typename SourceOp, typename TargetOp>
3402struct AMDGPUTensorLoadStoreOpLowering
3403 : public ConvertOpToLLVMPattern<SourceOp> {
3404 using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
3405 using Adaptor = typename ConvertOpToLLVMPattern<SourceOp>::OpAdaptor;
3406 AMDGPUTensorLoadStoreOpLowering(const LLVMTypeConverter &converter,
3407 Chipset chipset)
3408 : ConvertOpToLLVMPattern<SourceOp>(converter), chipset(chipset) {}
3409 Chipset chipset;
3410
3411 LogicalResult
3412 matchAndRewrite(SourceOp op, Adaptor adaptor,
3413 ConversionPatternRewriter &rewriter) const override {
3414 if (chipset < kGfx1250)
3415 return op->emitOpError("is only supported on gfx1250");
3416
3417 ValueRange desc = adaptor.getDesc();
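// `desc` holds the four descriptor groups produced by the descriptor
// lowering; forward them directly to the ROCDL tensor load/store op.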
3418 rewriter.replaceOpWithNewOp<TargetOp>(op, desc[0], desc[1], desc[2],
3419 desc[3], /*cachePolicy=*/0,
3420 /*alias_scopes=*/nullptr,
3421 /*noalias_scopes=*/nullptr,
3422 /*tbaa=*/nullptr);
3423 return success();
3424 }
3425};
3426
3427struct ConvertAMDGPUToROCDLPass
3428 : public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
3429 using Base::Base;
3430
3431 void runOnOperation() override {
3432 MLIRContext *ctx = &getContext();
3433 FailureOr<Chipset> maybeChipset = Chipset::parse(chipset);
3434 if (failed(maybeChipset)) {
3435 emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
3436 return signalPassFailure();
3437 }
3438
3439 RewritePatternSet patterns(ctx);
3440 LLVMTypeConverter converter(ctx);
3441
3442 populateAMDGPUToROCDLConversionPatterns(converter, patterns, *maybeChipset);
3444 LLVMConversionTarget target(getContext());
3445 target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
3446 target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
3447 target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
3448 if (failed(applyPartialConversion(getOperation(), target,
3449 std::move(patterns))))
3450 signalPassFailure();
3451 }
3452};
3453} // namespace
3454
3455 void mlir::populateCommonGPUTypeAndAttributeConversions(
3456 TypeConverter &typeConverter) {
3457 populateGpuMemorySpaceAttributeConversions(
3458 typeConverter, [](gpu::AddressSpace space) {
3459 switch (space) {
3460 case gpu::AddressSpace::Global:
3461 return ROCDL::ROCDLDialect::kGlobalMemoryAddressSpace;
3462 case gpu::AddressSpace::Workgroup:
3463 return ROCDL::ROCDLDialect::kSharedMemoryAddressSpace;
3464 case gpu::AddressSpace::Private:
3465 return ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace;
3466 }
3467 llvm_unreachable("unknown address space enum value");
3468 });
3469}
3470
3471 void mlir::populateAMDGPUTypeAndAttributeConversions(
3472 TypeConverter &typeConverter) {
3473 typeConverter.addTypeAttributeConversion(
3474 [](BaseMemRefType type, amdgpu::AddressSpaceAttr as)
3475 -> TypeConverter::AttributeConversionResult {
3476 MLIRContext *ctx = as.getContext();
3477 Type i64 = IntegerType::get(ctx, 64);
3478 switch (as.getValue()) {
3479 case amdgpu::AddressSpace::FatRawBuffer:
3480 return IntegerAttr::get(i64, 7);
3481 case amdgpu::AddressSpace::BufferRsrc:
3482 return IntegerAttr::get(i64, 8);
3483 case amdgpu::AddressSpace::FatStructuredBuffer:
3484 return IntegerAttr::get(i64, 9);
3485 }
3486 return TypeConverter::AttributeConversionResult::abort();
3487 });
3488 typeConverter.addConversion([&](TDMBaseType type) -> Type {
3489 Type i32 = IntegerType::get(type.getContext(), 32);
3490 return typeConverter.convertType(VectorType::get(4, i32));
3491 });
3492 typeConverter.addConversion([&](TDMGatherBaseType type) -> Type {
3493 Type i32 = IntegerType::get(type.getContext(), 32);
3494 return typeConverter.convertType(VectorType::get(4, i32));
3495 });
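// A TDM descriptor expands to four values - <4 x i32>, <8 x i32>, <4 x i32>,
// <4 x i32> - matching the dgroup0..dgroup3 layout produced above.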
3496 typeConverter.addConversion(
3497 [&](TDMDescriptorType type,
3498 SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
3499 Type i32 = IntegerType::get(type.getContext(), 32);
3500 Type v4i32 = typeConverter.convertType(VectorType::get(4, i32));
3501 Type v8i32 = typeConverter.convertType(VectorType::get(8, i32));
3502 llvm::append_values(result, v4i32, v8i32, v4i32, v4i32);
3503 return success();
3504 });
3505
3506 auto addUnrealizedCast = [](OpBuilder &builder, TypeRange types,
3507 ValueRange inputs,
3508 Location loc) -> SmallVector<Value> {
3509 // Only create unrealized_conversion_cast for TDMDescriptorType.
3510 // All other types which are not expected, should be
3511 // materialized by other target materialization functions.
3512 if (inputs.size() != 1)
3513 return {};
3514
3515 if (!isa<TDMDescriptorType>(inputs[0].getType()))
3516 return {};
3517
3518 auto cast = UnrealizedConversionCastOp::create(builder, loc, types, inputs);
3519 return cast.getResults();
3520 };
3521
3522 typeConverter.addTargetMaterialization(addUnrealizedCast);
3523}
3524
3525 void mlir::populateAMDGPUToROCDLConversionPatterns(
3526 LLVMTypeConverter &converter, RewritePatternSet &patterns,
3527 Chipset chipset) {
3528 populateAMDGPUTypeAndAttributeConversions(converter);
3529 patterns
3530 .add<FatRawBufferCastLowering,
3531 RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
3532 RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
3533 RawBufferOpLowering<RawBufferAtomicFaddOp,
3534 ROCDL::RawPtrBufferAtomicFaddOp>,
3535 RawBufferOpLowering<RawBufferAtomicFmaxOp,
3536 ROCDL::RawPtrBufferAtomicFmaxOp>,
3537 RawBufferOpLowering<RawBufferAtomicSmaxOp,
3538 ROCDL::RawPtrBufferAtomicSmaxOp>,
3539 RawBufferOpLowering<RawBufferAtomicUminOp,
3540 ROCDL::RawPtrBufferAtomicUminOp>,
3541 RawBufferOpLowering<RawBufferAtomicCmpswapOp,
3542 ROCDL::RawPtrBufferAtomicCmpSwap>,
3543 AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
3544 SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
3545 SparseMFMAOpLowering, WMMAOpLowering, ScaledWMMAOpLowering,
3546 ExtPackedFp8OpLowering, ScaledExtPackedMatrixOpLowering,
3547 ScaledExtPackedOpLowering, PackedScaledTruncOpLowering,
3548 PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
3549 GatherToLDSOpLowering, TransposeLoadOpLowering,
3550 AMDGPUPermlaneLowering, AMDGPUMakeDmaBaseLowering<MakeDmaBaseOp>,
3551 AMDGPUMakeDmaBaseLowering<MakeGatherDmaBaseOp>,
3552 AMDGPULowerDescriptor<MakeDmaDescriptorOp>,
3553 AMDGPULowerDescriptor<MakeGatherDmaDescriptorOp>,
3554 AMDGPUTensorLoadStoreOpLowering<TensorLoadToLDSOp,
3555 ROCDL::TensorLoadToLDSOp>,
3556 AMDGPUTensorLoadStoreOpLowering<TensorStoreFromLDSOp,
3557 ROCDL::TensorStoreFromLDSOp>>(
3558 converter, chipset);
3559 patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
3560}