MLIR 23.0.0git
AMDGPUToROCDL.cpp
Go to the documentation of this file.
1//===- AMDGPUToROCDL.cpp - AMDGPU to ROCDL dialect conversion -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
20#include "mlir/IR/Attributes.h"
23#include "mlir/IR/Matchers.h"
25#include "mlir/Pass/Pass.h"
26
28
29#include "llvm/ADT/STLExtras.h"
30#include "llvm/ADT/TypeSwitch.h"
31#include "llvm/Support/Casting.h"
32#include "llvm/Support/ErrorHandling.h"
33#include <optional>
34
35namespace mlir {
36#define GEN_PASS_DEF_CONVERTAMDGPUTOROCDLPASS
37#include "mlir/Conversion/Passes.h.inc"
38} // namespace mlir
39
40using namespace mlir;
41using namespace mlir::amdgpu;
42
43// Define commonly used chipsets versions for convenience.
// Arguments appear to be (major, minor, stepping) matching the gfxMMMS name,
// e.g. Chipset(9, 0, 0xa) == gfx90a — confirm against the Chipset declaration.
44constexpr Chipset kGfx908 = Chipset(9, 0, 8);
45constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
46constexpr Chipset kGfx942 = Chipset(9, 4, 2);
47constexpr Chipset kGfx950 = Chipset(9, 5, 0);
48constexpr Chipset kGfx1250 = Chipset(12, 5, 0);
49
50/// Convert an unsigned number `val` to i32.
51static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
52 Location loc, Value val) {
53 IntegerType i32 = rewriter.getI32Type();
54 // Force check that `val` is of int type.
55 auto valTy = cast<IntegerType>(val.getType());
56 if (i32 == valTy)
57 return val;
58 return valTy.getWidth() > 32
59 ? Value(LLVM::TruncOp::create(rewriter, loc, i32, val))
60 : Value(LLVM::ZExtOp::create(rewriter, loc, i32, val));
61}
62
63static Value createI32Constant(ConversionPatternRewriter &rewriter,
64 Location loc, int32_t value) {
65 return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), value);
66}
67
68/// Convert an unsigned number `val` to i64.
69static Value convertUnsignedToI64(ConversionPatternRewriter &rewriter,
70 Location loc, Value val) {
71 IntegerType i64 = rewriter.getI64Type();
72 // Force check that `val` is of int type.
73 auto valTy = cast<IntegerType>(val.getType());
74 if (i64 == valTy)
75 return val;
76 return valTy.getWidth() > 64
77 ? Value(LLVM::TruncOp::create(rewriter, loc, i64, val))
78 : Value(LLVM::ZExtOp::create(rewriter, loc, i64, val));
79}
80
81static Value createI64Constant(ConversionPatternRewriter &rewriter,
82 Location loc, int64_t value) {
83 return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), value);
84}
85
86/// Returns the linear index used to access an element in the memref.
87static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
88 Location loc, MemRefDescriptor &memRefDescriptor,
90 IntegerType i32 = rewriter.getI32Type();
92 for (auto [i, increment, stride] : llvm::enumerate(indices, strides)) {
93 if (stride != 1) { // Skip if stride is 1.
94 Value strideValue =
95 ShapedType::isDynamic(stride)
96 ? convertUnsignedToI32(rewriter, loc,
97 memRefDescriptor.stride(rewriter, loc, i))
98 : LLVM::ConstantOp::create(rewriter, loc, i32, stride);
99 increment = LLVM::MulOp::create(rewriter, loc, increment, strideValue);
100 }
101 index = index ? LLVM::AddOp::create(rewriter, loc, index, increment)
102 : increment;
103 }
104 return index ? index : createI32Constant(rewriter, loc, 0);
105}
106
107/// Compute the contents of the `num_records` field for a given memref
108/// descriptor - that is, the number of bytes that's one element past the
109/// greatest possible valid index into the memref.
110static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc,
111 MemRefType memrefType,
112 MemRefDescriptor &memrefDescriptor,
113 ArrayRef<int64_t> strides, int64_t elementByteWidth,
114 amdgpu::Chipset chipset, bool boundsCheck) {
115 if (chipset >= kGfx1250 && !boundsCheck) {
116 constexpr int64_t first45bits = (1ll << 45) - 1;
117 return createI64Constant(rewriter, loc, first45bits);
118 }
119 if (memrefType.hasStaticShape() &&
120 !llvm::any_of(strides, ShapedType::isDynamic)) {
121 int64_t size = memrefType.getRank() == 0 ? 1 : 0;
122 ArrayRef<int64_t> shape = memrefType.getShape();
123 for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
124 size = std::max(shape[i] * strides[i], size);
125 size = size * elementByteWidth;
126 return createI64Constant(rewriter, loc, size);
127 }
128 Value maxIndex;
129 for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
130 Value size = memrefDescriptor.size(rewriter, loc, i);
131 Value stride = memrefDescriptor.stride(rewriter, loc, i);
132 Value maxThisDim = LLVM::MulOp::create(rewriter, loc, size, stride);
133 maxIndex = maxIndex
134 ? LLVM::UMaxOp::create(rewriter, loc, maxIndex, maxThisDim)
135 : maxThisDim;
136 }
137 Value maxIndexI64 = convertUnsignedToI64(rewriter, loc, maxIndex);
138 Value byteWidthConst = createI64Constant(rewriter, loc, elementByteWidth);
139 return LLVM::MulOp::create(rewriter, loc, maxIndexI64, byteWidthConst);
140}
141
142static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc,
143 Value basePointer, Value numRecords,
144 bool boundsCheck, amdgpu::Chipset chipset,
145 Value cacheSwizzleStride = nullptr,
146 unsigned addressSpace = 8) {
147 // The stride value is generally 0. However, on MI-300 and onward, you can
148 // enable a cache swizzling mode by setting bit 14 of the stride field
149 // and setting that stride to a cache stride.
150 Type i16 = rewriter.getI16Type();
151 Value stride;
152 if (chipset.majorVersion == 9 && chipset >= kGfx942 && cacheSwizzleStride) {
153 Value cacheStrideZext =
154 LLVM::ZExtOp::create(rewriter, loc, i16, cacheSwizzleStride);
155 Value swizzleBit = LLVM::ConstantOp::create(
156 rewriter, loc, i16, rewriter.getI16IntegerAttr(1 << 14));
157 stride = LLVM::OrOp::create(rewriter, loc, cacheStrideZext, swizzleBit,
158 /*isDisjoint=*/true);
159 } else {
160 stride = LLVM::ConstantOp::create(rewriter, loc, i16,
161 rewriter.getI16IntegerAttr(0));
162 }
163
164 uint32_t flags = 0;
165 if (chipset >= kGfx1250) {
166 // Flag word:
167 // bit 0: swizzle
168 // bit 1: 0 means (total_offset + payload > numRecords)
169 // 1 means ((total_offset + payload >) numRecords) || ((offset +
170 // payload) > stride) only applied when swizzle_enable = 0. keep at
171 // zero.
172 // whether oob is done depends on numRecords.
173 // bits 2-3: Type (must be 0)
174 } else {
175 // Get the number of elements.
176 // Flag word:
177 // bits 0-11: dst sel, ignored by these intrinsics
178 // bits 12-14: data format (ignored, must be nonzero, 7=float)
179 // bits 15-18: data format (ignored, must be nonzero, 4=32bit)
180 // bit 19: In nested heap (0 here)
181 // bit 20: Behavior on unmap (0 means "return 0 / ignore")
182 // bits 21-22: Index stride for swizzles (N/A)
183 // bit 23: Add thread ID (0)
184 // bit 24: Reserved to 1 (RDNA) or 0 (CDNA)
185 // bits 25-26: Reserved (0)
186 // bit 27: Buffer is non-volatile (CDNA only)
187 // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 =
188 // none, 3 = either swizzles or testing against offset field) RDNA only
189 // bits 30-31: Type (must be 0)
190 flags |= (7 << 12) | (4 << 15);
191 if (chipset.majorVersion >= 10) {
192 flags |= (1 << 24);
193 uint32_t oob = boundsCheck ? 3 : 2;
194 flags |= (oob << 28);
195 }
196 }
197 Value flagsConst = createI32Constant(rewriter, loc, flags);
198 Type rsrcType =
199 LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
200 Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
201 loc, rsrcType, basePointer, stride, numRecords, flagsConst);
202 return resource;
203}
204
205namespace {
206struct FatRawBufferCastLowering
207 : public ConvertOpToLLVMPattern<FatRawBufferCastOp> {
208 FatRawBufferCastLowering(const LLVMTypeConverter &converter, Chipset chipset)
209 : ConvertOpToLLVMPattern<FatRawBufferCastOp>(converter),
210 chipset(chipset) {}
211
212 Chipset chipset;
213
214 LogicalResult
215 matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor,
216 ConversionPatternRewriter &rewriter) const override {
217 Location loc = op.getLoc();
218 Value memRef = adaptor.getSource();
219 Value unconvertedMemref = op.getSource();
220 MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());
221 MemRefDescriptor descriptor(memRef);
222
223 DataLayout dataLayout = DataLayout::closest(op);
224 int64_t elementByteWidth =
225 dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;
226
227 int64_t unusedOffset = 0;
228 SmallVector<int64_t, 5> strideVals;
229 if (failed(memrefType.getStridesAndOffset(strideVals, unusedOffset)))
230 return op.emitOpError("Can't lower non-stride-offset memrefs");
231
232 Value numRecords = adaptor.getValidBytes();
233 if (!numRecords)
234 numRecords =
235 getNumRecords(rewriter, loc, memrefType, descriptor, strideVals,
236 elementByteWidth, chipset, adaptor.getBoundsCheck());
237
238 Value basePointer =
239 adaptor.getResetOffset()
240 ? descriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
241 memrefType)
242 : descriptor.alignedPtr(rewriter, loc);
243
244 Value offset = adaptor.getResetOffset()
245 ? LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
246 rewriter.getIndexAttr(0))
247 : descriptor.offset(rewriter, loc);
248
249 bool hasSizes = memrefType.getRank() > 0;
250 // No need to unpack() and pack() all the individual sizes and strides,
251 // so we'll just extract the arrays.
252 Value sizes = hasSizes
253 ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
255 : Value{};
256 Value strides =
257 hasSizes ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
259 : Value{};
260
261 Value fatPtr = makeBufferRsrc(
262 rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(),
263 chipset, adaptor.getCacheSwizzleStride(), /*addressSpace=*/7);
264
265 Value result = MemRefDescriptor::poison(
266 rewriter, loc,
267 getTypeConverter()->convertType(op.getResult().getType()));
268 SmallVector<int64_t> pos{kAllocatedPtrPosInMemRefDescriptor};
269 result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr, pos);
270 result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr,
272 result = LLVM::InsertValueOp::create(rewriter, loc, result, offset,
274 if (hasSizes) {
275 result = LLVM::InsertValueOp::create(rewriter, loc, result, sizes,
277 result = LLVM::InsertValueOp::create(rewriter, loc, result, strides,
279 }
280 rewriter.replaceOp(op, result);
281 return success();
282 }
283};
284
285/// Define lowering patterns for raw buffer ops
286template <typename GpuOp, typename Intrinsic>
287struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
288 RawBufferOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
289 : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}
290
291 Chipset chipset;
292 static constexpr uint32_t maxVectorOpWidth = 128;
293
294 LogicalResult
295 matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor,
296 ConversionPatternRewriter &rewriter) const override {
297 Location loc = gpuOp.getLoc();
298 Value memref = adaptor.getMemref();
299 Value unconvertedMemref = gpuOp.getMemref();
300 MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());
301
302 if (chipset.majorVersion < 9)
303 return gpuOp.emitOpError("raw buffer ops require GCN or higher");
304
305 Value storeData = adaptor.getODSOperands(0)[0];
306 if (storeData == memref) // no write component to this op
307 storeData = Value();
308 Type wantedDataType;
309 if (storeData)
310 wantedDataType = storeData.getType();
311 else
312 wantedDataType = gpuOp.getODSResults(0)[0].getType();
313
314 Value atomicCmpData = Value();
315 // Operand index 1 of a load is the indices, trying to read them can crash.
316 if (storeData) {
317 Value maybeCmpData = adaptor.getODSOperands(1)[0];
318 if (maybeCmpData != memref)
319 atomicCmpData = maybeCmpData;
320 }
321
322 Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
323
324 Type i32 = rewriter.getI32Type();
325
326 // Get the type size in bytes.
327 DataLayout dataLayout = DataLayout::closest(gpuOp);
328 int64_t elementByteWidth =
329 dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;
330 Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);
331
332 // If we want to load a vector<NxT> with total size <= 32
333 // bits, use a scalar load and bitcast it. Similarly, if bitsize(T) < 32
334 // and the total load size is >= 32, use a vector load of N / (bitsize(T) /
335 // 32) x i32 and bitcast. Also, the CAS intrinsic requires integer operands,
336 // so bitcast any floats to integers.
337 Type llvmBufferValType = llvmWantedDataType;
338 if (atomicCmpData) {
339 if (auto floatType = dyn_cast<FloatType>(wantedDataType))
340 llvmBufferValType = this->getTypeConverter()->convertType(
341 rewriter.getIntegerType(floatType.getWidth()));
342 }
343 if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
344 uint32_t vecLen = dataVector.getNumElements();
345 uint32_t elemBits =
346 dataLayout.getTypeSizeInBits(dataVector.getElementType());
347 uint32_t totalBits = elemBits * vecLen;
348 bool usePackedFp16 =
349 isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
350 if (totalBits > maxVectorOpWidth)
351 return gpuOp.emitOpError(
352 "Total width of loads or stores must be no more than " +
353 Twine(maxVectorOpWidth) + " bits, but we call for " +
354 Twine(totalBits) +
355 " bits. This should've been caught in validation");
356 if (!usePackedFp16 && elemBits < 32) {
357 if (totalBits > 32) {
358 if (totalBits % 32 != 0)
359 return gpuOp.emitOpError("Load or store of more than 32-bits that "
360 "doesn't fit into words. Can't happen\n");
361 llvmBufferValType = this->typeConverter->convertType(
362 VectorType::get(totalBits / 32, i32));
363 } else {
364 llvmBufferValType = this->typeConverter->convertType(
365 rewriter.getIntegerType(totalBits));
366 }
367 }
368 }
369 if (auto vecType = dyn_cast<VectorType>(llvmBufferValType)) {
370 // Buffer intrinsics doesn't support 1-element vectors, cast them to
371 // scalars.
372 if (vecType.getNumElements() == 1)
373 llvmBufferValType = vecType.getElementType();
374 }
375
376 SmallVector<Value, 6> args;
377 if (storeData) {
378 if (llvmBufferValType != llvmWantedDataType) {
379 Value castForStore = LLVM::BitcastOp::create(
380 rewriter, loc, llvmBufferValType, storeData);
381 args.push_back(castForStore);
382 } else {
383 args.push_back(storeData);
384 }
385 }
386
387 if (atomicCmpData) {
388 if (llvmBufferValType != llvmWantedDataType) {
389 Value castForCmp = LLVM::BitcastOp::create(
390 rewriter, loc, llvmBufferValType, atomicCmpData);
391 args.push_back(castForCmp);
392 } else {
393 args.push_back(atomicCmpData);
394 }
395 }
396
397 // Construct buffer descriptor from memref, attributes
398 int64_t offset = 0;
399 SmallVector<int64_t, 5> strides;
400 if (failed(memrefType.getStridesAndOffset(strides, offset)))
401 return gpuOp.emitOpError("Can't lower non-stride-offset memrefs");
402
403 MemRefDescriptor memrefDescriptor(memref);
404
405 Value ptr = memrefDescriptor.bufferPtr(
406 rewriter, loc, *this->getTypeConverter(), memrefType);
407 Value numRecords =
408 getNumRecords(rewriter, loc, memrefType, memrefDescriptor, strides,
409 elementByteWidth, chipset, adaptor.getBoundsCheck());
410 Value resource = makeBufferRsrc(rewriter, loc, ptr, numRecords,
411 adaptor.getBoundsCheck(), chipset);
412 args.push_back(resource);
413
414 // Indexing (voffset)
415 Value voffset = getLinearIndexI32(rewriter, loc, memrefDescriptor,
416 adaptor.getIndices(), strides);
417 if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
418 indexOffset && *indexOffset > 0) {
419 Value extraOffsetConst = createI32Constant(rewriter, loc, *indexOffset);
420 voffset = voffset ? LLVM::AddOp::create(rewriter, loc, voffset,
421 extraOffsetConst)
422 : extraOffsetConst;
423 }
424 voffset = LLVM::MulOp::create(rewriter, loc, voffset, byteWidthConst);
425 args.push_back(voffset);
426
427 // SGPR offset.
428 Value sgprOffset = adaptor.getSgprOffset();
429 if (!sgprOffset)
430 sgprOffset = createI32Constant(rewriter, loc, 0);
431 sgprOffset = LLVM::MulOp::create(rewriter, loc, sgprOffset, byteWidthConst);
432 args.push_back(sgprOffset);
433
434 // bit 0: GLC = 0 (atomics drop value, less coherency)
435 // bits 1-2: SLC, DLC = 0 (similarly)
436 // bit 3: swizzled (0 for raw)
437 args.push_back(createI32Constant(rewriter, loc, 0));
438
439 llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
440 llvmBufferValType);
441 Operation *lowered = Intrinsic::create(rewriter, loc, resultTypes, args,
442 ArrayRef<NamedAttribute>());
443 if (lowered->getNumResults() == 1) {
444 Value replacement = lowered->getResult(0);
445 if (llvmBufferValType != llvmWantedDataType) {
446 replacement = LLVM::BitcastOp::create(rewriter, loc, llvmWantedDataType,
448 }
449 rewriter.replaceOp(gpuOp, replacement);
450 } else {
451 rewriter.eraseOp(gpuOp);
452 }
453 return success();
454 }
455};
456
457// TODO: AMDGPU backend already have all this bitpacking logic, we should move
458// it to some common place.
459/// Vmcnt, Expcnt and Lgkmcnt are decoded as follows:
460/// Vmcnt = Waitcnt[3:0] (pre-gfx9)
461/// Vmcnt = Waitcnt[15:14,3:0] (gfx9,10)
462/// Vmcnt = Waitcnt[15:10] (gfx11)
463/// Expcnt = Waitcnt[6:4] (pre-gfx11)
464/// Expcnt = Waitcnt[2:0] (gfx11)
465/// Lgkmcnt = Waitcnt[11:8] (pre-gfx10)
466/// Lgkmcnt = Waitcnt[13:8] (gfx10)
467/// Lgkmcnt = Waitcnt[9:4] (gfx11)
468static FailureOr<unsigned> encodeWaitcnt(Chipset chipset, unsigned vmcnt,
469 unsigned expcnt, unsigned lgkmcnt) {
470 if (chipset.majorVersion < 9) {
471 vmcnt = std::min(15u, vmcnt);
472 expcnt = std::min(7u, expcnt);
473 lgkmcnt = std::min(15u, lgkmcnt);
474 return vmcnt | (expcnt << 4) | (lgkmcnt << 8);
475 }
476 if (chipset.majorVersion == 9) {
477 vmcnt = std::min(63u, vmcnt);
478 expcnt = std::min(7u, expcnt);
479 lgkmcnt = std::min(15u, lgkmcnt);
480 unsigned lowBits = vmcnt & 0xF;
481 unsigned highBits = (vmcnt >> 4) << 14;
482 unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
483 return lowBits | highBits | otherCnts;
484 }
485 if (chipset.majorVersion == 10) {
486 vmcnt = std::min(63u, vmcnt);
487 expcnt = std::min(7u, expcnt);
488 lgkmcnt = std::min(63u, lgkmcnt);
489 unsigned lowBits = vmcnt & 0xF;
490 unsigned highBits = (vmcnt >> 4) << 14;
491 unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
492 return lowBits | highBits | otherCnts;
493 }
494 if (chipset.majorVersion == 11) {
495 vmcnt = std::min(63u, vmcnt);
496 expcnt = std::min(7u, expcnt);
497 lgkmcnt = std::min(63u, lgkmcnt);
498 return (vmcnt << 10) | expcnt | (lgkmcnt << 4);
499 }
500 return failure();
501}
502
503struct MemoryCounterWaitOpLowering
504 : public ConvertOpToLLVMPattern<MemoryCounterWaitOp> {
505 MemoryCounterWaitOpLowering(const LLVMTypeConverter &converter,
506 Chipset chipset)
507 : ConvertOpToLLVMPattern<MemoryCounterWaitOp>(converter),
508 chipset(chipset) {}
509
510 Chipset chipset;
511
512 LogicalResult
513 matchAndRewrite(MemoryCounterWaitOp op, OpAdaptor adaptor,
514 ConversionPatternRewriter &rewriter) const override {
515 if (chipset.majorVersion >= 12) {
516 Location loc = op.getLoc();
517 if (std::optional<int> ds = adaptor.getDs())
518 ROCDL::WaitDscntOp::create(rewriter, loc, *ds);
519
520 if (std::optional<int> load = adaptor.getLoad())
521 ROCDL::WaitLoadcntOp::create(rewriter, loc, *load);
522
523 if (std::optional<int> store = adaptor.getStore())
524 ROCDL::WaitStorecntOp::create(rewriter, loc, *store);
525
526 if (std::optional<int> exp = adaptor.getExp())
527 ROCDL::WaitExpcntOp::create(rewriter, loc, *exp);
528
529 if (std::optional<int> tensor = adaptor.getTensor())
530 ROCDL::WaitTensorcntOp::create(rewriter, loc, *tensor);
531
532 rewriter.eraseOp(op);
533 return success();
534 }
535
536 if (adaptor.getTensor())
537 return op.emitOpError("unsupported chipset");
538
539 auto getVal = [](Attribute attr) -> unsigned {
540 if (attr)
541 return cast<IntegerAttr>(attr).getInt();
542
543 // This value will be clamped to the maximum value for the chipset.
544 return 1024;
545 };
546 unsigned ds = getVal(adaptor.getDsAttr());
547 unsigned exp = getVal(adaptor.getExpAttr());
548
549 unsigned vmcnt = 1024;
550 Attribute load = adaptor.getLoadAttr();
551 Attribute store = adaptor.getStoreAttr();
552 if (load && store) {
553 vmcnt = getVal(load) + getVal(store);
554 } else if (load) {
555 vmcnt = getVal(load);
556 } else if (store) {
557 vmcnt = getVal(store);
558 }
559
560 FailureOr<unsigned> waitcnt = encodeWaitcnt(chipset, vmcnt, exp, ds);
561 if (failed(waitcnt))
562 return op.emitOpError("unsupported chipset");
563
564 rewriter.replaceOpWithNewOp<ROCDL::SWaitcntOp>(op, *waitcnt);
565 return success();
566 }
567};
568
569struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
570 LDSBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
571 : ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {}
572
573 Chipset chipset;
574
575 LogicalResult
576 matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
577 ConversionPatternRewriter &rewriter) const override {
578 Location loc = op.getLoc();
579 // This ensures that waits on global memory aren't introduced on
580 // chips that don't have the BackOffBarrier feature enabled in LLVM.
581 bool requiresInlineAsm = chipset < kGfx90a;
582
583 Attribute mmra =
584 rewriter.getAttr<LLVM::MMRATagAttr>("amdgpu-synchronize-as", "local");
585 // Note: while there *is* a workgroup-one-as scope, this, when combined with
586 // the MMRA, will lead to the fence having no effect. This is because the
587 // codepaths for an atomic load or store will observe that a
588 // one-address-space atomic to LDS requires no synchronization because
589 // operations on LDS are totally ordered with respect to each other, and so
590 // will not emit the correct waitcnt operations that these fences are
591 // intended to produce. Therefore, we use a broader type of fence and rely
592 // on the MMRA to relax it to the semantics we want.
593 StringRef scope = "workgroup";
594
595 auto relFence = LLVM::FenceOp::create(rewriter, loc,
596 LLVM::AtomicOrdering::release, scope);
597 relFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra);
598 if (requiresInlineAsm) {
599 auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
600 LLVM::AsmDialect::AD_ATT);
601 const char *asmStr = ";;;WARNING: BREAKS DEBUG WATCHES\ns_barrier";
602 const char *constraints = "";
603 LLVM::InlineAsmOp::create(
604 rewriter, loc,
605 /*resultTypes=*/TypeRange(), /*operands=*/ValueRange(),
606 /*asm_string=*/asmStr, constraints, /*has_side_effects=*/true,
607 /*is_align_stack=*/false, LLVM::TailCallKind::None,
608 /*asm_dialect=*/asmDialectAttr,
609 /*operand_attrs=*/ArrayAttr());
610 } else if (chipset.majorVersion < 12) {
611 ROCDL::SBarrierOp::create(rewriter, loc);
612 } else {
613 ROCDL::BarrierSignalOp::create(rewriter, loc, -1);
614 ROCDL::BarrierWaitOp::create(rewriter, loc, -1);
615 }
616
617 auto acqFence = LLVM::FenceOp::create(rewriter, loc,
618 LLVM::AtomicOrdering::acquire, scope);
619 acqFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra);
620 rewriter.replaceOp(op, acqFence);
621 return success();
622 }
623};
624
625struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
626 SchedBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
627 : ConvertOpToLLVMPattern<SchedBarrierOp>(converter), chipset(chipset) {}
628
629 Chipset chipset;
630
631 LogicalResult
632 matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor,
633 ConversionPatternRewriter &rewriter) const override {
634 rewriter.replaceOpWithNewOp<ROCDL::SchedBarrier>(op,
635 (uint32_t)op.getOpts());
636 return success();
637 }
638};
639
640} // namespace
641
642/// Pack small float vector operands (fp4/fp6/fp8/bf16) into the format
643/// expected by scaled matrix multiply intrinsics (MFMA/WMMA).
644///
645/// Specifically:
646/// 1. If the element type is bfloat16, bitcast it to i16 unless rocdl intrinsic
647/// allows bf16. Newer MFMAs support bf16 types on operand, check
648/// IntrinsicsAMDGPU.td file for reference.
649/// 2. If instead we have a more than 64-bit quantity, use a <N / 4 x i32>
650/// instead, which is what the f8f6f4 intrinsics use.
651/// 3. If `input` is a vector of N <= 8 bytes, bitcast it to a (N * 8)-bit
652/// integer.
653///
654/// Note that the type of `input` has already been LLVM type converted:
655/// therefore 8-bit and smaller floats are represented as their corresponding
656/// `iN` integers.
657static Value packSmallFloatVectorOperand(ConversionPatternRewriter &rewriter,
658 Location loc, Value input,
659 bool allowBf16 = true) {
660 Type inputType = input.getType();
661 if (auto vectorType = dyn_cast<VectorType>(inputType)) {
662 if (vectorType.getElementType().isBF16() && !allowBf16)
663 return LLVM::BitcastOp::create(
664 rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
665 if (vectorType.getElementType().isInteger(8) &&
666 vectorType.getNumElements() <= 8)
667 return LLVM::BitcastOp::create(
668 rewriter, loc,
669 rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
670 if (isa<IntegerType>(vectorType.getElementType()) &&
671 vectorType.getElementTypeBitWidth() <= 8) {
672 int64_t numWords = llvm::divideCeil(
673 vectorType.getNumElements() * vectorType.getElementTypeBitWidth(),
674 32);
675 return LLVM::BitcastOp::create(
676 rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()),
677 input);
678 }
679 }
680 return input;
681}
682
683/// Converts sparse MFMA/WMMA (smfmac/swmmac) operands to the expected ROCDL
684/// types.
685static Value convertSparseVectorOperand(ConversionPatternRewriter &rewriter,
686 Location loc, Value input,
687 bool allowBf16 = true) {
688 Type inputType = input.getType();
689 auto vectorType = cast<VectorType>(inputType);
690 // bf16 -> i16 when not allowed (pre-gfx950).
691 if (vectorType.getElementType().isBF16() && !allowBf16)
692 return LLVM::BitcastOp::create(
693 rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
694 // i8/fp8 vectors -> vector<Nxi32>.
695 if (isa<IntegerType>(vectorType.getElementType()) &&
696 vectorType.getElementTypeBitWidth() <= 8) {
697 int64_t numWords = llvm::divideCeil(
698 vectorType.getNumElements() * vectorType.getElementTypeBitWidth(), 32);
699 Type castType = (numWords > 1)
700 ? Type{VectorType::get(numWords, rewriter.getI32Type())}
701 : rewriter.getI32Type();
702 return LLVM::BitcastOp::create(rewriter, loc, castType, input);
703 }
704 return input;
705}
706
707/// Converts the scaled MFMA/WMMA operands, `scalesA` and `scalesB`, from MLIR
708/// AMDGPU dialect convention to ROCDL and LLVM AMDGPU intrinsics convention.
709///
710/// Specifically:
711/// 1. If `input` is a i8 value, zero extend it to i32
712/// 2. If `input` is a vector of length 4 or 8 and type i8, cast it to i32
713///
714/// Note that the type of `input` has already been LLVM type converted:
715/// therefore 8-bit and smaller floats are represented as their corresponding
716/// `iN` integers.
717static Value castScaleOperand(ConversionPatternRewriter &rewriter, Location loc,
718 Value input) {
719 return TypeSwitch<Type, Value>(input.getType())
720 .Case([&](IntegerType) {
721 // Handle scalar i8: zero extend to i32.
722 return LLVM::ZExtOp::create(rewriter, loc, rewriter.getI32Type(),
723 input);
724 })
725 .Case([&](VectorType vectorType) {
726 // Handle vector<4xi8> -> i32 or vector<8xi8> -> i64.
727 int64_t numElements = vectorType.getNumElements();
728 assert((numElements == 4 || numElements == 8) &&
729 "scale operand must be a vector of length 4 or 8");
730 IntegerType outputType =
731 (numElements == 4) ? rewriter.getI32Type() : rewriter.getI64Type();
732 return LLVM::BitcastOp::create(rewriter, loc, outputType, input);
733 })
734 .DefaultUnreachable("unexpected input type for scale operand");
735}
736
737/// Maps f8 scale element types to WMMA scale format codes.
738static std::optional<uint32_t> getWmmaScaleFormat(Type elemType) {
740 .Case([](Float8E8M0FNUType) { return 0; })
741 .Case([](Float8E4M3FNType) { return 2; })
742 .Default(std::nullopt);
743}
744
745/// Determines the ROCDL intrinsic name for scaled WMMA based on dimensions
746/// and scale block size (16 or 32).
/// NOTE(review): the declarator line carrying this function's name and its
/// (m, n, k, isScale16) parameter list was dropped by the doc extractor
/// between the two lines below — restore it from the original source before
/// building this file.
747static std::optional<StringRef>
/// 16x16x128 shapes select the f8f6f4 variants; isScale16 picks the
/// scale16 flavor of the intrinsic.
749 if (m == 16 && n == 16 && k == 128)
750 return isScale16
751 ? ROCDL::wmma_scale16_f32_16x16x128_f8f6f4::getOperationName()
752 : ROCDL::wmma_scale_f32_16x16x128_f8f6f4::getOperationName();
753
/// 32x16x128 shapes select the f4 variants.
754 if (m == 32 && n == 16 && k == 128)
755 return isScale16 ? ROCDL::wmma_scale16_f32_32x16x128_f4::getOperationName()
756 : ROCDL::wmma_scale_f32_32x16x128_f4::getOperationName();
757
/// Any other shape is unsupported.
758 return std::nullopt;
759}
760
761/// Push an input operand. If it is a float type, nothing to do. If it is
762/// an integer type, then we need to also push its signdness (1 for signed, 0
763/// for unsigned) and we need to pack the input 16xi8 vector into a 4xi32
764/// vector (or the 8xi8 vector into a 2xi32 one for gfx12+).
765/// We also need to convert bfloat inputs to i16 to account for the bfloat
766/// intrinsics having been defined before the AMD backend supported bfloat. We
767/// similarly need to pack 8-bit float types into integers as if they were i8
768/// (which they are for the backend's purposes).
770 ConversionPatternRewriter &rewriter, Location loc,
771 const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput,
772 Value mlirInput, SmallVectorImpl<Value> &operands,
773 SmallVectorImpl<NamedAttribute> &attrs, StringRef attrName) {
774 Type inputType = llvmInput.getType();
775 auto vectorType = dyn_cast<VectorType>(inputType);
776 if (!vectorType) {
777 operands.push_back(llvmInput);
778 return;
779 }
780 Type elemType = vectorType.getElementType();
781 if (elemType.getIntOrFloatBitWidth() > 8) {
782 operands.push_back(llvmInput);
783 return;
784 }
785
786 // We need to check the type of the input before conversion to properly test
787 // for int8. This is because, in LLVM, fp8 type is converted to int8, so the
788 // fp8/int8 information is lost during the conversion process.
789 auto mlirInputType = cast<VectorType>(mlirInput.getType());
790 bool isInputInteger = mlirInputType.getElementType().isInteger();
791 if (isInputInteger) {
792 // if element type is 8-bit signed or unsigned, ignore the isUnsigned flag
793 bool localIsUnsigned = isUnsigned;
794 if (elemType.isUnsignedInteger()) {
795 localIsUnsigned = true;
796 } else if (elemType.isSignedInteger()) {
797 localIsUnsigned = false;
798 }
799 attrs.push_back(
800 NamedAttribute(attrName, rewriter.getBoolAttr(!localIsUnsigned)));
801 }
802
803 int64_t numBits =
804 vectorType.getNumElements() * elemType.getIntOrFloatBitWidth();
805 Type i32 = rewriter.getI32Type();
806 Type intrinsicInType = numBits <= 32
807 ? (Type)rewriter.getIntegerType(numBits)
808 : (Type)VectorType::get(numBits / 32, i32);
809 auto llvmIntrinsicInType = typeConverter->convertType(intrinsicInType);
810 Value castInput = rewriter.createOrFold<LLVM::BitcastOp>(
811 loc, llvmIntrinsicInType, llvmInput);
812 // The wave64-mode 16x16x16 intrinsics that take 4-bit integers only need
813 // (256 / 64) * 4 = 16 bits of input (on gfx12+) but take i32 arguments.
814 // Add in the zeros here.
815 if (numBits < 32)
816 castInput = LLVM::ZExtOp::create(rewriter, loc, i32, castInput);
817 operands.push_back(castInput);
818}
819
820/// Push the output operand. For many cases this is only pushing the output in
821/// the operand list. But when we have f16 -> f16 or bf16 -> bf16 intrinsics,
822/// since the same numbers of VGPRs is used, we need to decide if to store the
823/// result in the upper 16 bits of the VGPRs or in the lower part. To store the
824/// result in the lower 16 bits, set subwordOffset to 1, otherwise result will
825/// be stored it in the upper part. The subwordOffset must not be set for gfx12,
826/// as the instructions have been changed to return fewer registers instead.
827static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
828 Location loc,
829 const TypeConverter *typeConverter,
830 Value output, int32_t subwordOffset,
831 bool clamp, SmallVectorImpl<Value> &operands,
833 Type inputType = output.getType();
834 auto vectorType = dyn_cast<VectorType>(inputType);
835 Type elemType = vectorType.getElementType();
836 operands.push_back(output);
837 if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) {
838 attrs.push_back(
839 NamedAttribute("opsel", rewriter.getBoolAttr(subwordOffset)));
840 } else if (elemType.isInteger(32)) {
841 attrs.push_back(NamedAttribute("clamp", rewriter.getBoolAttr(clamp)));
842 }
843}
844
845/// Return true if `type` is the E5M2 variant of an 8-bit float that is
846/// supported by the `_bf8` instructions on the given `chipset`.
847static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type) {
848 return (chipset == kGfx942 && isa<Float8E5M2FNUZType>(type)) ||
849 (hasOcpFp8(chipset) && isa<Float8E5M2Type>(type));
850}
851
852/// Return true if `type` is the E4M3FN variant of an 8-bit float that is
853/// supported by the `_fp8` instructions on the given `chipset`.
854static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type) {
855 return (chipset == kGfx942 && isa<Float8E4M3FNUZType>(type)) ||
856 (hasOcpFp8(chipset) && isa<Float8E4M3FNType>(type));
857}
858
859/// Return the `rocdl` intrinsic corresponding to a MFMA operation `mfma`
860/// if one exists. This includes checking to ensure the intrinsic is supported
861/// on the architecture you are compiling for.
862static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
863 Chipset chipset) {
864 uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
865 b = mfma.getBlocks();
866 Type sourceElem = getElementTypeOrSelf(mfma.getSourceA().getType());
867 Type destElem = getElementTypeOrSelf(mfma.getDestC().getType());
868
869 if (sourceElem.isF32() && destElem.isF32()) {
870 if (mfma.getReducePrecision() && chipset >= kGfx942) {
871 if (m == 32 && n == 32 && k == 4 && b == 1)
872 return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
873 if (m == 16 && n == 16 && k == 8 && b == 1)
874 return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
875 }
876 if (m == 32 && n == 32 && k == 1 && b == 2)
877 return ROCDL::mfma_f32_32x32x1f32::getOperationName();
878 if (m == 16 && n == 16 && k == 1 && b == 4)
879 return ROCDL::mfma_f32_16x16x1f32::getOperationName();
880 if (m == 4 && n == 4 && k == 1 && b == 16)
881 return ROCDL::mfma_f32_4x4x1f32::getOperationName();
882 if (m == 32 && n == 32 && k == 2 && b == 1)
883 return ROCDL::mfma_f32_32x32x2f32::getOperationName();
884 if (m == 16 && n == 16 && k == 4 && b == 1)
885 return ROCDL::mfma_f32_16x16x4f32::getOperationName();
886 }
887
888 if (sourceElem.isF16() && destElem.isF32()) {
889 if (chipset >= kGfx950) {
890 if (m == 32 && n == 32 && k == 16 && b == 1)
891 return ROCDL::mfma_f32_32x32x16_f16::getOperationName();
892 if (m == 16 && n == 16 && k == 32 && b == 1)
893 return ROCDL::mfma_f32_16x16x32_f16::getOperationName();
894 }
895 if (m == 32 && n == 32 && k == 4 && b == 2)
896 return ROCDL::mfma_f32_32x32x4f16::getOperationName();
897 if (m == 16 && n == 16 && k == 4 && b == 4)
898 return ROCDL::mfma_f32_16x16x4f16::getOperationName();
899 if (m == 4 && n == 4 && k == 4 && b == 16)
900 return ROCDL::mfma_f32_4x4x4f16::getOperationName();
901 if (m == 32 && n == 32 && k == 8 && b == 1)
902 return ROCDL::mfma_f32_32x32x8f16::getOperationName();
903 if (m == 16 && n == 16 && k == 16 && b == 1)
904 return ROCDL::mfma_f32_16x16x16f16::getOperationName();
905 }
906
907 if (sourceElem.isBF16() && destElem.isF32()) {
908 if (chipset >= kGfx950) {
909 if (m == 32 && n == 32 && k == 16 && b == 1)
910 return ROCDL::mfma_f32_32x32x16_bf16::getOperationName();
911 if (m == 16 && n == 16 && k == 32 && b == 1)
912 return ROCDL::mfma_f32_16x16x32_bf16::getOperationName();
913 }
914 if (chipset >= kGfx90a) {
915 if (m == 32 && n == 32 && k == 4 && b == 2)
916 return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
917 if (m == 16 && n == 16 && k == 4 && b == 4)
918 return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
919 if (m == 4 && n == 4 && k == 4 && b == 16)
920 return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
921 if (m == 32 && n == 32 && k == 8 && b == 1)
922 return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
923 if (m == 16 && n == 16 && k == 16 && b == 1)
924 return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
925 }
926 if (m == 32 && n == 32 && k == 2 && b == 2)
927 return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
928 if (m == 16 && n == 16 && k == 2 && b == 4)
929 return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
930 if (m == 4 && n == 4 && k == 2 && b == 16)
931 return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
932 if (m == 32 && n == 32 && k == 4 && b == 1)
933 return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
934 if (m == 16 && n == 16 && k == 8 && b == 1)
935 return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
936 }
937
938 if (sourceElem.isInteger(8) && destElem.isInteger(32)) {
939 if (chipset >= kGfx950) {
940 if (m == 32 && n == 32 && k == 32 && b == 1)
941 return ROCDL::mfma_i32_32x32x32_i8::getOperationName();
942 if (m == 16 && n == 16 && k == 64 && b == 1)
943 return ROCDL::mfma_i32_16x16x64_i8::getOperationName();
944 }
945 if (m == 32 && n == 32 && k == 4 && b == 2)
946 return ROCDL::mfma_i32_32x32x4i8::getOperationName();
947 if (m == 16 && n == 16 && k == 4 && b == 4)
948 return ROCDL::mfma_i32_16x16x4i8::getOperationName();
949 if (m == 4 && n == 4 && k == 4 && b == 16)
950 return ROCDL::mfma_i32_4x4x4i8::getOperationName();
951 if (m == 32 && n == 32 && k == 8 && b == 1)
952 return ROCDL::mfma_i32_32x32x8i8::getOperationName();
953 if (m == 16 && n == 16 && k == 16 && b == 1)
954 return ROCDL::mfma_i32_16x16x16i8::getOperationName();
955 if (m == 32 && n == 32 && k == 16 && b == 1 && chipset >= kGfx942)
956 return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
957 if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >= kGfx942)
958 return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
959 }
960
961 if (sourceElem.isF64() && destElem.isF64() && chipset >= kGfx90a) {
962 if (m == 16 && n == 16 && k == 4 && b == 1)
963 return ROCDL::mfma_f64_16x16x4f64::getOperationName();
964 if (m == 4 && n == 4 && k == 4 && b == 4)
965 return ROCDL::mfma_f64_4x4x4f64::getOperationName();
966 }
967
968 if (destElem.isF32() && typeIsExpectedBf8ForChipset(chipset, sourceElem)) {
969 // Known to be correct because there are no scalar f8 instructions and
970 // because a length mismatch will have been caught by the verifier.
971 Type sourceBElem =
972 cast<VectorType>(mfma.getSourceB().getType()).getElementType();
973 if (m == 16 && n == 16 && k == 32 && b == 1) {
974 if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
975 return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
976 if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
977 return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
978 }
979 if (m == 32 && n == 32 && k == 16 && b == 1) {
980 if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
981 return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
982 if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
983 return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
984 }
985 }
986
987 if (destElem.isF32() && typeIsExpectedFp8ForChipset(chipset, sourceElem)) {
988 Type sourceBElem =
989 cast<VectorType>(mfma.getSourceB().getType()).getElementType();
990 if (m == 16 && n == 16 && k == 32 && b == 1) {
991 if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
992 return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
993 if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
994 return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
995 }
996 if (m == 32 && n == 32 && k == 16 && b == 1) {
997 if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
998 return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
999 if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
1000 return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
1001 }
1002 }
1003
1004 return std::nullopt;
1005}
1006
1007static std::optional<uint32_t> smallFloatTypeToFormatCode(Type mlirElemType) {
1009 .Case([](Float8E4M3FNType) { return 0u; })
1010 .Case([](Float8E5M2Type) { return 1u; })
1011 .Case([](Float6E2M3FNType) { return 2u; })
1012 .Case([](Float6E3M2FNType) { return 3u; })
1013 .Case([](Float4E2M1FNType) { return 4u; })
1014 .Default(std::nullopt);
1015}
1016
1017/// If there is a scaled MFMA instruction for the input element types `aType`
1018/// and `bType`, output type `destType`, problem size M, N, K, and B (number of
1019/// blocks) on the given `chipset`, return a tuple consisting of the
1020/// OperationName of the intrinsic and the type codes that need to be passed to
1021/// that intrinsic. Note that this is also used to implement some un-scaled
1022/// MFMAs, since the compiler represents the ordinary instruction as a "scaled"
1023/// MFMA with a scale of 0.
1024static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1025mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m,
1026 uint32_t n, uint32_t k, uint32_t b, Chipset chipset) {
1027 aType = getElementTypeOrSelf(aType);
1028 bType = getElementTypeOrSelf(bType);
1029 destType = getElementTypeOrSelf(destType);
1030
1031 if (chipset < kGfx950)
1032 return std::nullopt;
1033 if (!isa<Float32Type>(destType))
1034 return std::nullopt;
1035
1036 std::optional<uint32_t> aTypeCode = smallFloatTypeToFormatCode(aType);
1037 std::optional<uint32_t> bTypeCode = smallFloatTypeToFormatCode(bType);
1038 if (!aTypeCode || !bTypeCode)
1039 return std::nullopt;
1040
1041 if (m == 32 && n == 32 && k == 64 && b == 1)
1042 return std::tuple{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(),
1043 *aTypeCode, *bTypeCode};
1044 if (m == 16 && n == 16 && k == 128 && b == 1)
1045 return std::tuple{
1046 ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), *aTypeCode,
1047 *bTypeCode};
1048
1049 return std::nullopt;
1050}
1051
1052static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1053mfmaOpToScaledIntrinsic(MFMAOp mfma, Chipset chipset) {
1055 mfma.getSourceA().getType(), mfma.getSourceB().getType(),
1056 mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(),
1057 mfma.getBlocks(), chipset);
1058}
1059
1060static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1061mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
1062 return mfmaOpToScaledIntrinsic(smfma.getSourceA().getType(),
1063 smfma.getSourceB().getType(),
1064 smfma.getDestC().getType(), smfma.getM(),
1065 smfma.getN(), smfma.getK(), 1u, chipset);
1066}
1067
1068/// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
1069/// for RDNA3/4 architectures.
1070static std::optional<StringRef>
1071wmmaOpToIntrinsicRDNA(Type elemSourceType, Type elemBSourceType,
1072 Type elemDestType, uint32_t k, bool isRDNA3) {
1073 using fp8 = Float8E4M3FNType;
1074 using bf8 = Float8E5M2Type;
1075
1076 // Handle k == 16 for RDNA3/4.
1077 if (k == 16) {
1078 // Common patterns for RDNA3 and RDNA4.
1079 if (elemSourceType.isF16() && elemDestType.isF32())
1080 return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
1081 if (elemSourceType.isBF16() && elemDestType.isF32())
1082 return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
1083 if (elemSourceType.isF16() && elemDestType.isF16())
1084 return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
1085 if (elemSourceType.isBF16() && elemDestType.isBF16())
1086 return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
1087 if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
1088 return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
1089
1090 // RDNA3 specific patterns.
1091 if (isRDNA3) {
1092 if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
1093 return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
1094 return std::nullopt;
1095 }
1096
1097 // RDNA4 specific patterns (fp8/bf8).
1098 if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
1099 elemDestType.isF32())
1100 return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
1101 if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
1102 elemDestType.isF32())
1103 return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
1104 if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
1105 elemDestType.isF32())
1106 return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
1107 if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
1108 elemDestType.isF32())
1109 return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
1110 if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
1111 return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
1112
1113 return std::nullopt;
1114 }
1115
1116 // Handle k == 32 for RDNA4.
1117 if (k == 32 && !isRDNA3) {
1118 if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
1119 return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
1120 }
1121
1122 return std::nullopt;
1123}
1124
1125/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
1126/// for the gfx1250 architecture.
1127static std::optional<StringRef> wmmaOpToIntrinsicGfx1250(Type elemSourceType,
1128 Type elemBSourceType,
1129 Type elemDestType,
1130 uint32_t k) {
1131 using fp8 = Float8E4M3FNType;
1132 using bf8 = Float8E5M2Type;
1133
1134 if (k == 4) {
1135 if (elemSourceType.isF32() && elemDestType.isF32())
1136 return ROCDL::wmma_f32_16x16x4_f32::getOperationName();
1137
1138 return std::nullopt;
1139 }
1140
1141 if (k == 32) {
1142 if (elemSourceType.isF16() && elemDestType.isF32())
1143 return ROCDL::wmma_f32_16x16x32_f16::getOperationName();
1144 if (elemSourceType.isBF16() && elemDestType.isF32())
1145 return ROCDL::wmma_f32_16x16x32_bf16::getOperationName();
1146 if (elemSourceType.isF16() && elemDestType.isF16())
1147 return ROCDL::wmma_f16_16x16x32_f16::getOperationName();
1148 if (elemSourceType.isBF16() && elemDestType.isBF16())
1149 return ROCDL::wmma_bf16_16x16x32_bf16::getOperationName();
1150
1151 return std::nullopt;
1152 }
1153
1154 if (k == 64) {
1155 if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1156 if (elemDestType.isF32())
1157 return ROCDL::wmma_f32_16x16x64_fp8_fp8::getOperationName();
1158 if (elemDestType.isF16())
1159 return ROCDL::wmma_f16_16x16x64_fp8_fp8::getOperationName();
1160 }
1161 if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1162 if (elemDestType.isF32())
1163 return ROCDL::wmma_f32_16x16x64_fp8_bf8::getOperationName();
1164 if (elemDestType.isF16())
1165 return ROCDL::wmma_f16_16x16x64_fp8_bf8::getOperationName();
1166 }
1167 if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1168 if (elemDestType.isF32())
1169 return ROCDL::wmma_f32_16x16x64_bf8_bf8::getOperationName();
1170 if (elemDestType.isF16())
1171 return ROCDL::wmma_f16_16x16x64_bf8_bf8::getOperationName();
1172 }
1173 if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1174 if (elemDestType.isF32())
1175 return ROCDL::wmma_f32_16x16x64_bf8_fp8::getOperationName();
1176 if (elemDestType.isF16())
1177 return ROCDL::wmma_f16_16x16x64_bf8_fp8::getOperationName();
1178 }
1179 if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
1180 return ROCDL::wmma_i32_16x16x64_iu8::getOperationName();
1181
1182 return std::nullopt;
1183 }
1184
1185 if (k == 128) {
1186 if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1187 if (elemDestType.isF32())
1188 return ROCDL::wmma_f32_16x16x128_fp8_fp8::getOperationName();
1189 if (elemDestType.isF16())
1190 return ROCDL::wmma_f16_16x16x128_fp8_fp8::getOperationName();
1191 }
1192 if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1193 if (elemDestType.isF32())
1194 return ROCDL::wmma_f32_16x16x128_fp8_bf8::getOperationName();
1195 if (elemDestType.isF16())
1196 return ROCDL::wmma_f16_16x16x128_fp8_bf8::getOperationName();
1197 }
1198 if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1199 if (elemDestType.isF32())
1200 return ROCDL::wmma_f32_16x16x128_bf8_bf8::getOperationName();
1201 if (elemDestType.isF16())
1202 return ROCDL::wmma_f16_16x16x128_bf8_bf8::getOperationName();
1203 }
1204 if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1205 if (elemDestType.isF32())
1206 return ROCDL::wmma_f32_16x16x128_bf8_fp8::getOperationName();
1207 if (elemDestType.isF16())
1208 return ROCDL::wmma_f16_16x16x128_bf8_fp8::getOperationName();
1209 }
1210
1211 return std::nullopt;
1212 }
1213
1214 return std::nullopt;
1215}
1216
1217/// Returns the `rocdl` intrinsic corresponding to a SparseMFMA (smfmac)
1218/// operation if one exists. This includes checking to ensure the intrinsic is
1219/// supported on the architecture you are compiling for.
1220static std::optional<StringRef> smfmacOpToIntrinsic(SparseMFMAOp op,
1221 Chipset chipset) {
1222 bool isGfx950 = chipset >= kGfx950;
1223 auto isFp8 = [&](Type t) { return typeIsExpectedFp8ForChipset(chipset, t); };
1224 auto isBf8 = [&](Type t) { return typeIsExpectedBf8ForChipset(chipset, t); };
1225
1226 uint32_t m = op.getM(), n = op.getN(), k = op.getK();
1227 Type sourceAElem = getElementTypeOrSelf(op.getSourceA().getType());
1228 Type sourceBElem = getElementTypeOrSelf(op.getSourceB().getType());
1229 Type destElem = getElementTypeOrSelf(op.getDestC().getType());
1230
1231 if (m == 16 && n == 16 && k == 32) {
1232 if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32())
1233 return ROCDL::smfmac_f32_16x16x32_f16::getOperationName();
1234 if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32())
1235 return ROCDL::smfmac_f32_16x16x32_bf16::getOperationName();
1236 }
1237
1238 if (m == 16 && n == 16 && k == 64) {
1239 if (isGfx950) {
1240 if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32())
1241 return ROCDL::smfmac_f32_16x16x64_f16::getOperationName();
1242 if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32())
1243 return ROCDL::smfmac_f32_16x16x64_bf16::getOperationName();
1244 }
1245 if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
1246 destElem.isInteger(32))
1247 return ROCDL::smfmac_i32_16x16x64_i8::getOperationName();
1248 if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
1249 return ROCDL::smfmac_f32_16x16x64_fp8_fp8::getOperationName();
1250 if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
1251 return ROCDL::smfmac_f32_16x16x64_fp8_bf8::getOperationName();
1252 if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
1253 return ROCDL::smfmac_f32_16x16x64_bf8_fp8::getOperationName();
1254 if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
1255 return ROCDL::smfmac_f32_16x16x64_bf8_bf8::getOperationName();
1256 }
1257
1258 if (m == 16 && n == 16 && k == 128 && isGfx950) {
1259 if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
1260 destElem.isInteger(32))
1261 return ROCDL::smfmac_i32_16x16x128_i8::getOperationName();
1262 if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
1263 return ROCDL::smfmac_f32_16x16x128_fp8_fp8::getOperationName();
1264 if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
1265 return ROCDL::smfmac_f32_16x16x128_fp8_bf8::getOperationName();
1266 if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
1267 return ROCDL::smfmac_f32_16x16x128_bf8_fp8::getOperationName();
1268 if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
1269 return ROCDL::smfmac_f32_16x16x128_bf8_bf8::getOperationName();
1270 }
1271
1272 if (m == 32 && n == 32 && k == 16) {
1273 if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32())
1274 return ROCDL::smfmac_f32_32x32x16_f16::getOperationName();
1275 if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32())
1276 return ROCDL::smfmac_f32_32x32x16_bf16::getOperationName();
1277 }
1278
1279 if (m == 32 && n == 32 && k == 32) {
1280 if (isGfx950) {
1281 if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32())
1282 return ROCDL::smfmac_f32_32x32x32_f16::getOperationName();
1283 if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32())
1284 return ROCDL::smfmac_f32_32x32x32_bf16::getOperationName();
1285 }
1286 if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
1287 destElem.isInteger(32))
1288 return ROCDL::smfmac_i32_32x32x32_i8::getOperationName();
1289 if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
1290 return ROCDL::smfmac_f32_32x32x32_fp8_fp8::getOperationName();
1291 if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
1292 return ROCDL::smfmac_f32_32x32x32_fp8_bf8::getOperationName();
1293 if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
1294 return ROCDL::smfmac_f32_32x32x32_bf8_fp8::getOperationName();
1295 if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
1296 return ROCDL::smfmac_f32_32x32x32_bf8_bf8::getOperationName();
1297 }
1298
1299 if (m == 32 && n == 32 && k == 64 && isGfx950) {
1300 if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
1301 destElem.isInteger(32))
1302 return ROCDL::smfmac_i32_32x32x64_i8::getOperationName();
1303 if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
1304 return ROCDL::smfmac_f32_32x32x64_fp8_fp8::getOperationName();
1305 if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
1306 return ROCDL::smfmac_f32_32x32x64_fp8_bf8::getOperationName();
1307 if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
1308 return ROCDL::smfmac_f32_32x32x64_bf8_fp8::getOperationName();
1309 if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
1310 return ROCDL::smfmac_f32_32x32x64_bf8_bf8::getOperationName();
1311 }
1312
1313 return std::nullopt;
1314}
1315
1316/// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
1317/// if one exists. This includes checking to ensure the intrinsic is supported
1318/// on the architecture you are compiling for.
1319static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
1320 Chipset chipset) {
1321 auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
1322 auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
1323 auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
1324 Type elemSourceType = sourceVectorType.getElementType();
1325 Type elemBSourceType = sourceBVectorType.getElementType();
1326 Type elemDestType = destVectorType.getElementType();
1327
1328 const uint32_t k = wmma.getK();
1329 const bool isRDNA3 = chipset.majorVersion == 11;
1330 const bool isRDNA4 = chipset.majorVersion == 12 && chipset.minorVersion == 0;
1331
1332 // Handle RDNA3 and RDNA4.
1333 if (isRDNA3 || isRDNA4)
1334 return wmmaOpToIntrinsicRDNA(elemSourceType, elemBSourceType, elemDestType,
1335 k, isRDNA3);
1336
1337 // Handle gfx1250.
1338 if (chipset == kGfx1250)
1339 return wmmaOpToIntrinsicGfx1250(elemSourceType, elemBSourceType,
1340 elemDestType, k);
1341
1342 return std::nullopt;
1343}
1344
1345/// Returns the `rocdl` intrinsic corresponding to a SparseWMMA operation
1346/// `swmmac` if one exists. This includes checking to ensure the intrinsic is
1347/// supported on the architecture you are compiling for.
// NOTE(review): this extraction is missing the `struct SparseWMMAOpInfo {`
// header and several member declarations (original lines 1348, 1350-1352).
// Judging by the brace-init sites in sparseWMMAOpToIntrinsic below, the
// struct carries the intrinsic name plus three bool flags — TODO restore the
// missing declarations from the upstream source before building.
1349 StringRef name;
1353};
1354
1355static std::optional<SparseWMMAOpInfo>
1356sparseWMMAOpToIntrinsic(SparseWMMAOp swmmac, Chipset chipset) {
1357 Type sourceAElem = getElementTypeOrSelf(swmmac.getSourceA().getType());
1358 Type sourceBElem = getElementTypeOrSelf(swmmac.getSourceB().getType());
1359 Type destElem = getElementTypeOrSelf(swmmac.getDestC().getType());
1360
1361 uint32_t m = swmmac.getM(), n = swmmac.getN(), k = swmmac.getK();
1362
1363 if ((m != 16) || (n != 16))
1364 return std::nullopt;
1365
1366 const bool isRDNA4 = chipset.majorVersion == 12 && chipset.minorVersion == 0;
1367 if (isRDNA4) {
1368 if (k == 32) {
1369 if (destElem.isF32() && sourceAElem.isF16() && sourceBElem.isF16())
1370 return SparseWMMAOpInfo{
1371 ROCDL::swmmac_f32_16x16x32_f16::getOperationName(), false, false,
1372 false};
1373 if (destElem.isF32() && sourceAElem.isBF16() && sourceBElem.isBF16())
1374 return SparseWMMAOpInfo{
1375 ROCDL::swmmac_f32_16x16x32_bf16::getOperationName(), false, false,
1376 false};
1377 if (destElem.isF16() && sourceAElem.isF16() && sourceBElem.isF16())
1378 return SparseWMMAOpInfo{
1379 ROCDL::swmmac_f16_16x16x32_f16::getOperationName(), false, false,
1380 false};
1381 if (destElem.isBF16() && sourceAElem.isBF16() && sourceBElem.isBF16())
1382 return SparseWMMAOpInfo{
1383 ROCDL::swmmac_bf16_16x16x32_bf16::getOperationName(), false, false,
1384 false};
1385 if (destElem.isInteger(32) && sourceAElem.isInteger(8) &&
1386 sourceBElem.isInteger(8))
1387 return SparseWMMAOpInfo{
1388 ROCDL::swmmac_i32_16x16x32_iu8::getOperationName(), true, false,
1389 true};
1390 if (destElem.isInteger(32) && sourceAElem.isInteger(4) &&
1391 sourceBElem.isInteger(4))
1392 return SparseWMMAOpInfo{
1393 ROCDL::swmmac_i32_16x16x32_iu4::getOperationName(), true, false,
1394 true};
1395 if (destElem.isF32() && sourceAElem.isF8E4M3FN() &&
1396 sourceBElem.isF8E4M3FN())
1397 return SparseWMMAOpInfo{
1398 ROCDL::swmmac_f32_16x16x32_fp8_fp8::getOperationName(), false,
1399 false, false};
1400 if (destElem.isF32() && sourceAElem.isF8E4M3FN() &&
1401 sourceBElem.isF8E5M2())
1402 return SparseWMMAOpInfo{
1403 ROCDL::swmmac_f32_16x16x32_fp8_bf8::getOperationName(), false,
1404 false, false};
1405 if (destElem.isF32() && sourceAElem.isF8E5M2() &&
1406 sourceBElem.isF8E4M3FN())
1407 return SparseWMMAOpInfo{
1408 ROCDL::swmmac_f32_16x16x32_bf8_fp8::getOperationName(), false,
1409 false, false};
1410 if (destElem.isF32() && sourceAElem.isF8E5M2() && sourceBElem.isF8E5M2())
1411 return SparseWMMAOpInfo{
1412 ROCDL::swmmac_f32_16x16x32_bf8_bf8::getOperationName(), false,
1413 false, false};
1414 }
1415 if (k == 64) {
1416 if (destElem.isInteger(32) && sourceAElem.isInteger(4) &&
1417 sourceBElem.isInteger(4))
1418 return SparseWMMAOpInfo{
1419 ROCDL::swmmac_i32_16x16x64_iu4::getOperationName(), true, false,
1420 true};
1421 }
1422 }
1423
1424 const bool isGFX1250 = chipset == kGfx1250;
1425 const bool isWavesize64 = swmmac.getWave64();
1426 if (isGFX1250 && !isWavesize64) {
1427 if (k == 64) {
1428 if (destElem.isF32() && sourceAElem.isF16() && sourceBElem.isF16())
1429 return SparseWMMAOpInfo{
1430 ROCDL::swmmac_f32_16x16x64_f16::getOperationName(), true, true,
1431 false};
1432 if (destElem.isF32() && sourceAElem.isBF16() && sourceBElem.isBF16())
1433 return SparseWMMAOpInfo{
1434 ROCDL::swmmac_f32_16x16x64_bf16::getOperationName(), true, true,
1435 false};
1436 if (destElem.isF16() && sourceAElem.isF16() && sourceBElem.isF16())
1437 return SparseWMMAOpInfo{
1438 ROCDL::swmmac_f16_16x16x64_f16::getOperationName(), true, true,
1439 false};
1440 if (destElem.isBF16() && sourceAElem.isBF16() && sourceBElem.isBF16())
1441 return SparseWMMAOpInfo{
1442 ROCDL::swmmac_bf16_16x16x64_bf16::getOperationName(), true, true,
1443 false};
1444 }
1445 if (k == 128) {
1446 if (destElem.isF32() && sourceAElem.isF8E4M3FN() &&
1447 sourceBElem.isF8E4M3FN())
1448 return SparseWMMAOpInfo{
1449 ROCDL::swmmac_f32_16x16x128_fp8_fp8::getOperationName(), false,
1450 true, false};
1451 if (destElem.isF32() && sourceAElem.isF8E4M3FN() &&
1452 sourceBElem.isF8E5M2())
1453 return SparseWMMAOpInfo{
1454 ROCDL::swmmac_f32_16x16x128_fp8_bf8::getOperationName(), false,
1455 true, false};
1456 if (destElem.isF32() && sourceAElem.isF8E5M2() &&
1457 sourceBElem.isF8E4M3FN())
1458 return SparseWMMAOpInfo{
1459 ROCDL::swmmac_f32_16x16x128_bf8_fp8::getOperationName(), false,
1460 true, false};
1461 if (destElem.isF32() && sourceAElem.isF8E5M2() && sourceBElem.isF8E5M2())
1462 return SparseWMMAOpInfo{
1463 ROCDL::swmmac_f32_16x16x128_bf8_bf8::getOperationName(), false,
1464 true, false};
1465 if (destElem.isF16() && sourceAElem.isF8E4M3FN() &&
1466 sourceBElem.isF8E4M3FN())
1467 return SparseWMMAOpInfo{
1468 ROCDL::swmmac_f16_16x16x128_fp8_fp8::getOperationName(), false,
1469 true, false};
1470 if (destElem.isF16() && sourceAElem.isF8E4M3FN() &&
1471 sourceBElem.isF8E5M2())
1472 return SparseWMMAOpInfo{
1473 ROCDL::swmmac_f16_16x16x128_fp8_bf8::getOperationName(), false,
1474 true, false};
1475 if (destElem.isF16() && sourceAElem.isF8E5M2() &&
1476 sourceBElem.isF8E4M3FN())
1477 return SparseWMMAOpInfo{
1478 ROCDL::swmmac_f16_16x16x128_bf8_fp8::getOperationName(), false,
1479 true, false};
1480 if (destElem.isF16() && sourceAElem.isF8E5M2() && sourceBElem.isF8E5M2())
1481 return SparseWMMAOpInfo{
1482 ROCDL::swmmac_f16_16x16x128_bf8_bf8::getOperationName(), false,
1483 true, false};
1484 if (destElem.isF16() && sourceAElem.isInteger(8) &&
1485 sourceBElem.isInteger(8))
1486 return SparseWMMAOpInfo{
1487 ROCDL::swmmac_f16_16x16x128_bf8_bf8::getOperationName(), false,
1488 true, false};
1489 if (destElem.isInteger(32) && sourceAElem.isInteger(8) &&
1490 sourceBElem.isInteger(8))
1491 return SparseWMMAOpInfo{
1492 ROCDL::swmmac_i32_16x16x128_iu8::getOperationName(), true, true,
1493 true};
1494 }
1495 }
1496
1497 return std::nullopt;
1498}
1499
1500namespace {
1501struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
1502 MFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1503 : ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {}
1504
1505 Chipset chipset;
1506
1507 LogicalResult
1508 matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
1509 ConversionPatternRewriter &rewriter) const override {
1510 Location loc = op.getLoc();
1511 Type outType = typeConverter->convertType(op.getDestD().getType());
1512 Type intrinsicOutType = outType;
1513 if (auto outVecType = dyn_cast<VectorType>(outType))
1514 if (outVecType.getElementType().isBF16())
1515 intrinsicOutType = outVecType.clone(rewriter.getI16Type());
1516
1517 if (chipset.majorVersion != 9 || chipset < kGfx908)
1518 return op->emitOpError("MFMA only supported on gfx908+");
1519 uint32_t getBlgpField = static_cast<uint32_t>(op.getBlgp());
1520 if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
1521 if (chipset < kGfx942)
1522 return op.emitOpError("negation unsupported on older than gfx942");
1523 getBlgpField |=
1524 op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
1525 }
1526 std::optional<StringRef> maybeIntrinsic = mfmaOpToIntrinsic(op, chipset);
1527 std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1528 maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
1529 if (!maybeIntrinsic.has_value() && !maybeScaledIntrinsic.has_value())
1530 return op.emitOpError("no intrinsic matching MFMA size on given chipset");
1531
1532 bool isScaled =
1533 !maybeIntrinsic.has_value() && maybeScaledIntrinsic.has_value();
1534 if (isScaled &&
1535 (adaptor.getAbid() > 0 || getBlgpField > 0 || op.getCbsz() > 0)) {
1536 return op.emitOpError(
1537 "non-default abid, blgp, and cbsz aren't supported on MFMAs that can "
1538 "be scaled as those fields are used for type information");
1539 }
1540
1541 StringRef intrinsicName =
1542 isScaled ? std::get<0>(*maybeScaledIntrinsic) : *maybeIntrinsic;
1543 // Determine if we can use bf16 in the intrinsic. Newer MFMAs in gfx950+
1544 // allows bf16 as the input. For reference check IntrinsicsAMDGPU.td file.
1545 bool allowBf16 = [&]() {
1546 if (chipset < kGfx950)
1547 return false;
1548 if (isScaled)
1549 return true;
1550 return intrinsicName.contains("16x16x32.bf16") ||
1551 intrinsicName.contains("32x32x16.bf16");
1552 }();
1553 OperationState loweredOp(loc, intrinsicName);
1554 loweredOp.addTypes(intrinsicOutType);
1555 loweredOp.addOperands({packSmallFloatVectorOperand(
1556 rewriter, loc, adaptor.getSourceA(), allowBf16),
1558 rewriter, loc, adaptor.getSourceB(), allowBf16),
1559 adaptor.getDestC()});
1560 if (isScaled) {
1561 Value zero = createI32Constant(rewriter, loc, 0);
1562 auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
1563 loweredOp.addOperands({/*scale A=*/zero, /*scale B=*/zero});
1564 loweredOp.addAttributes({{"cbsz", rewriter.getI32IntegerAttr(aTypeCode)},
1565 {"blgp", rewriter.getI32IntegerAttr(bTypeCode)},
1566 {"opselA", rewriter.getI32IntegerAttr(0)},
1567 {"opselB", rewriter.getI32IntegerAttr(0)}});
1568 } else {
1569 loweredOp.addAttributes(
1570 {{"cbsz", rewriter.getI32IntegerAttr(op.getCbsz())},
1571 {"abid", rewriter.getI32IntegerAttr(op.getAbid())},
1572 {"blgp", rewriter.getI32IntegerAttr(getBlgpField)}});
1573 };
1574 Value lowered = rewriter.create(loweredOp)->getResult(0);
1575 if (outType != intrinsicOutType)
1576 lowered = LLVM::BitcastOp::create(rewriter, loc, outType, lowered);
1577 rewriter.replaceOp(op, lowered);
1578 return success();
1579 }
1580};
1581
1582struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
1583 ScaledMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1584 : ConvertOpToLLVMPattern(converter), chipset(chipset) {}
1585
1586 Chipset chipset;
1587
1588 LogicalResult
1589 matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor,
1590 ConversionPatternRewriter &rewriter) const override {
1591 Location loc = op.getLoc();
1592 Type intrinsicOutType = typeConverter->convertType(op.getDestD().getType());
1593
1594 if (chipset.majorVersion != 9 || chipset < kGfx950)
1595 return op->emitOpError("scaled MFMA only supported on gfx908+");
1596 std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1597 maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
1598 if (!maybeScaledIntrinsic.has_value())
1599 return op.emitOpError(
1600 "no intrinsic matching scaled MFMA size on given chipset");
1601
1602 auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
1603 OperationState loweredOp(loc, intrinsicName);
1604 loweredOp.addTypes(intrinsicOutType);
1605 loweredOp.addOperands(
1606 {packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceA()),
1607 packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceB()),
1608 adaptor.getDestC()});
1609 loweredOp.addOperands(
1610 {/*scales A*/
1611 castScaleOperand(rewriter, loc, adaptor.getScalesA()),
1612 /*scales B*/
1613 castScaleOperand(rewriter, loc, adaptor.getScalesB())});
1614 loweredOp.addAttributes(
1615 {{"cbsz", rewriter.getI32IntegerAttr(aTypeCode)},
1616 {"blgp", rewriter.getI32IntegerAttr(bTypeCode)},
1617 {"opselA", rewriter.getI32IntegerAttr(adaptor.getScalesIdxA())},
1618 {"opselB", rewriter.getI32IntegerAttr(adaptor.getScalesIdxB())}});
1619
1620 Value lowered = rewriter.create(loweredOp)->getResult(0);
1621 rewriter.replaceOp(op, lowered);
1622 return success();
1623 }
1624};
1625
1626struct SparseMFMAOpLowering : public ConvertOpToLLVMPattern<SparseMFMAOp> {
1627 SparseMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1628 : ConvertOpToLLVMPattern<SparseMFMAOp>(converter), chipset(chipset) {}
1629
1630 Chipset chipset;
1631
1632 LogicalResult
1633 matchAndRewrite(SparseMFMAOp op, SparseMFMAOpAdaptor adaptor,
1634 ConversionPatternRewriter &rewriter) const override {
1635 Location loc = op.getLoc();
1636 auto outType =
1637 typeConverter->convertType<VectorType>(op.getDestC().getType());
1638 if (!outType)
1639 return rewriter.notifyMatchFailure(op, "type conversion failed");
1640
1641 // smfmac is supported on gfx942 and gfx950.
1642 if (chipset.majorVersion != 9 || chipset < kGfx942)
1643 return op->emitOpError("sparse MFMA (smfmac) only supported on gfx942+");
1644 bool isGfx950 = chipset >= kGfx950;
1645
1646 Value a = convertSparseVectorOperand(rewriter, loc, adaptor.getSourceA(),
1647 isGfx950);
1648 Value b = convertSparseVectorOperand(rewriter, loc, adaptor.getSourceB(),
1649 isGfx950);
1650 Value c = adaptor.getDestC();
1651
1652 std::optional<StringRef> maybeIntrinsic = smfmacOpToIntrinsic(op, chipset);
1653 if (!maybeIntrinsic.has_value())
1654 return op.emitOpError(
1655 "no intrinsic matching sparse MFMA on the given chipset");
1656
1657 // Bitcast sparse indices from vector<4xi8> or vector<2xi16> to i32.
1658 Value sparseIdx = LLVM::BitcastOp::create(
1659 rewriter, loc, rewriter.getI32Type(), adaptor.getSparseIdx());
1660
1661 OperationState loweredOp(loc, maybeIntrinsic.value());
1662 loweredOp.addTypes(outType);
1663 loweredOp.addOperands({a, b, c, sparseIdx});
1664 loweredOp.addAttributes(
1665 {{"cbsz", rewriter.getI32IntegerAttr(op.getCbsz())},
1666 {"abid", rewriter.getI32IntegerAttr(op.getAbid())}});
1667 Value lowered = rewriter.create(loweredOp)->getResult(0);
1668 rewriter.replaceOp(op, lowered);
1669 return success();
1670 }
1671};
1672
/// Lowers `amdgpu.wmma` to the ROCDL WMMA intrinsics (gfx11/gfx12 only).
struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
  WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto outType =
        typeConverter->convertType<VectorType>(op.getDestD().getType());
    if (!outType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
      return op->emitOpError("WMMA only supported on gfx11 and gfx12");

    bool isGFX1250 = chipset >= kGfx1250;

    // The WMMA operations represent vectors of bf16s as vectors of i16s
    // (except on gfx1250), so we need to bitcast bfloats to i16 and then
    // bitcast them back.
    auto aType = cast<VectorType>(adaptor.getSourceA().getType());
    auto bType = cast<VectorType>(adaptor.getSourceB().getType());
    auto destCType = cast<VectorType>(adaptor.getDestC().getType());
    bool castAToI16 = aType.getElementType().isBF16() && !isGFX1250;
    bool castBToI16 = bType.getElementType().isBF16() && !isGFX1250;
    bool castDestCToI16 = destCType.getElementType().isBF16() && !isGFX1250;
    bool castOutToI16 = outType.getElementType().isBF16() && !isGFX1250;
    // rawOutType is the type the intrinsic actually produces; for bf16
    // results on pre-gfx1250 targets that is the i16 representation.
    VectorType rawOutType = outType;
    if (castOutToI16)
      rawOutType = outType.clone(rewriter.getI16Type());
    Value a = adaptor.getSourceA();
    if (castAToI16)
      a = LLVM::BitcastOp::create(rewriter, loc,
                                  aType.clone(rewriter.getI16Type()), a);
    Value b = adaptor.getSourceB();
    if (castBToI16)
      b = LLVM::BitcastOp::create(rewriter, loc,
                                  bType.clone(rewriter.getI16Type()), b);
    Value destC = adaptor.getDestC();
    if (castDestCToI16)
      destC = LLVM::BitcastOp::create(
          rewriter, loc, destCType.clone(rewriter.getI16Type()), destC);

    std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);

    if (!maybeIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching WMMA on the given chipset");

    if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0)
      return op.emitOpError("subwordOffset not supported on gfx12+");

    // The helpers append the A/B inputs and the accumulator (plus any
    // signA/signB/clamp-style attributes) in the order the intrinsic expects.
    SmallVector<Value, 4> operands;
    SmallVector<NamedAttribute, 4> attrs;
    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(), a,
                         op.getSourceA(), operands, attrs, "signA");
    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(), b,
                         op.getSourceB(), operands, attrs, "signB");
    wmmaPushOutputOperand(rewriter, loc, typeConverter, destC,
                          op.getSubwordOffset(), op.getClamp(), operands,
                          attrs);

    OperationState loweredOp(loc, *maybeIntrinsic);
    loweredOp.addTypes(rawOutType);
    loweredOp.addOperands(operands);
    loweredOp.addAttributes(attrs);
    Operation *lowered = rewriter.create(loweredOp);

    // If the intrinsic returned the i16 representation, bitcast back to the
    // bf16 result type callers expect.
    Operation *maybeCastBack = lowered;
    if (rawOutType != outType)
      maybeCastBack = LLVM::BitcastOp::create(rewriter, loc, outType,
                                              lowered->getResult(0));
    rewriter.replaceOp(op, maybeCastBack->getResults());

    return success();
  }
};
1752
/// Lowers `amdgpu.sparse_wmma` to the sparse WMMA intrinsic described by the
/// `SparseWMMAOpInfo` returned from `sparseWMMAOpToIntrinsic`.
struct SparseWMMAOpLowering : public ConvertOpToLLVMPattern<SparseWMMAOp> {
  SparseWMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<SparseWMMAOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(SparseWMMAOp op, SparseWMMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto outType =
        typeConverter->convertType<VectorType>(op.getDestD().getType());
    if (!outType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    std::optional<SparseWMMAOpInfo> maybeIntrinsic =
        sparseWMMAOpToIntrinsic(op, chipset);

    if (!maybeIntrinsic.has_value())
      return op.emitOpError(
          "no intrinsic matching Sparse WMMA on the given chipset");
    SparseWMMAOpInfo intrinsic = maybeIntrinsic.value();

    SmallVector<NamedAttribute> attrs;

    // Each optional modifier (sign, reuse, clamp) is forwarded only when the
    // selected intrinsic supports it; requesting an unsupported modifier is
    // an error rather than being silently dropped.
    if ((op.getUnsignedA() || op.getUnsignedB()) && !intrinsic.useSign)
      return op->emitOpError("intrinsic doesn't support unsign");
    if (intrinsic.useSign) {
      if (auto attr = op.getUnsignedAAttr())
        attrs.push_back({"signA", attr});
      if (auto attr = op.getUnsignedBAttr())
        attrs.push_back({"signB", attr});
    }

    if ((op.getReuseA() || op.getReuseB()) && !intrinsic.useReuse)
      return op->emitOpError("intrinsic doesn't support reuse");
    if (intrinsic.useReuse) {
      if (auto attr = op.getReuseAAttr())
        attrs.push_back({"reuseA", attr});
      if (auto attr = op.getReuseBAttr())
        attrs.push_back({"reuseB", attr});
    }

    if (op.getClamp() && !intrinsic.useClamp)
      return op->emitOpError("intrinsic doesn't support clamp");
    if (intrinsic.useClamp && op.getClampAttr())
      attrs.push_back({"clamp", op.getClampAttr()});

    const bool isGFX1250orHigher =
        chipset.majorVersion == 12 && chipset.minorVersion >= 5;
    Value a = convertSparseVectorOperand(rewriter, loc, adaptor.getSourceA(),
                                         isGFX1250orHigher);
    Value b = convertSparseVectorOperand(rewriter, loc, adaptor.getSourceB(),
                                         isGFX1250orHigher);
    Value c = adaptor.getDestC();
    VectorType rawOutType = outType;
    // Pre-gfx1250 targets also need the accumulator converted; the intrinsic
    // then produces that converted type, which is bitcast back below.
    if (!isGFX1250orHigher) {
      c = convertSparseVectorOperand(rewriter, loc, adaptor.getDestC(), false);
      rawOutType = cast<VectorType>(c.getType());
    }

    // Bitcast sparse indices from vector<4xi8> to i32.
    Value sparseIdx = LLVM::BitcastOp::create(
        rewriter, loc, rewriter.getI32Type(), adaptor.getSparseIdx());

    OperationState loweredOp(loc, intrinsic.name);
    loweredOp.addTypes(rawOutType);
    loweredOp.addOperands({a, b, c, sparseIdx});
    loweredOp.addAttributes(attrs);
    Operation *lowered = rewriter.create(loweredOp);

    Operation *maybeCastBack = lowered;
    if (rawOutType != outType)
      maybeCastBack = LLVM::BitcastOp::create(rewriter, loc, outType,
                                              lowered->getResult(0));
    rewriter.replaceOp(op, maybeCastBack->getResults());

    return success();
  }
};
1833
/// Lowers `amdgpu.scaled_wmma` to the gfx1250+ scaled WMMA intrinsics,
/// selecting between the "scale" and "scale16" variants based on the length
/// of the scale vectors.
struct ScaledWMMAOpLowering : public ConvertOpToLLVMPattern<ScaledWMMAOp> {
  ScaledWMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<ScaledWMMAOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(ScaledWMMAOp op, ScaledWMMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto outType =
        typeConverter->convertType<VectorType>(op.getDestD().getType());
    if (!outType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    if (chipset < kGfx1250)
      return op->emitOpError("WMMA scale only supported on gfx1250+");

    int64_t m = op.getM();
    int64_t n = op.getN();
    int64_t k = op.getK();

    // Encode the A/B small-float element types as intrinsic format codes.
    Type aElemType = getElementTypeOrSelf(op.getSourceA().getType());
    Type bElemType = getElementTypeOrSelf(op.getSourceB().getType());

    std::optional<uint32_t> aFmtCode = smallFloatTypeToFormatCode(aElemType);
    std::optional<uint32_t> bFmtCode = smallFloatTypeToFormatCode(bElemType);

    if (!aFmtCode || !bFmtCode)
      return op.emitOpError("unsupported element types for scaled_wmma");

    // Get scale vector types and determine variant (scale vs scale16).
    auto scaleAVecType = cast<VectorType>(op.getScaleA().getType());
    auto scaleBVecType = cast<VectorType>(op.getScaleB().getType());

    if (scaleAVecType.getNumElements() != scaleBVecType.getNumElements())
      return op.emitOpError("scaleA and scaleB must have equal vector length");

    // Extract scale format from element types.
    Type scaleAElemType = scaleAVecType.getElementType();
    Type scaleBElemType = scaleBVecType.getElementType();

    std::optional<uint32_t> scaleAFmt = getWmmaScaleFormat(scaleAElemType);
    std::optional<uint32_t> scaleBFmt = getWmmaScaleFormat(scaleBElemType);

    if (!scaleAFmt || !scaleBFmt)
      return op.emitOpError("unsupported scale element types");

    // Determine which intrinsic to use based on dimensions.
    // 8-element scale vectors select the "scale16" intrinsic variant.
    bool isScale16 = (scaleAVecType.getNumElements() == 8);
    std::optional<StringRef> intrinsicName =
        getScaledWmmaIntrinsicName(m, n, k, isScale16);
    if (!intrinsicName)
      return op.emitOpError("unsupported scaled_wmma dimensions: ")
             << m << "x" << n << "x" << k;

    SmallVector<NamedAttribute, 8> attrs;

    // The f4 variant does not have fmtA and fmtB attributes.
    bool is32x16 = (m == 32 && n == 16 && k == 128);
    if (!is32x16) {
      attrs.emplace_back("fmtA", rewriter.getI32IntegerAttr(*aFmtCode));
      attrs.emplace_back("fmtB", rewriter.getI32IntegerAttr(*bFmtCode));
    }

    // modC uses default value of 0.
    attrs.emplace_back("modC", rewriter.getI16IntegerAttr(0));

    // Scale attributes. Convert user-facing firstScaleLane (0 or 16) to the
    // half of the wave that is being selected (0 or 1).
    attrs.emplace_back(
        "scaleAType", rewriter.getI32IntegerAttr(op.getAFirstScaleLane() / 16));
    attrs.emplace_back("fmtScaleA", rewriter.getI32IntegerAttr(*scaleAFmt));
    attrs.emplace_back(
        "scaleBType", rewriter.getI32IntegerAttr(op.getBFirstScaleLane() / 16));
    attrs.emplace_back("fmtScaleB", rewriter.getI32IntegerAttr(*scaleBFmt));

    // Reuse flags use default value of false.
    attrs.emplace_back("reuseA", rewriter.getBoolAttr(false));
    attrs.emplace_back("reuseB", rewriter.getBoolAttr(false));

    // Convert typed float vectors to packed format.
    Value sourceA =
        packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceA());
    Value sourceB =
        packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceB());

    // Pack scale vectors into i32/i64.
    Value packedScaleA = castScaleOperand(rewriter, loc, adaptor.getScaleA());
    Value packedScaleB = castScaleOperand(rewriter, loc, adaptor.getScaleB());

    // Create the intrinsic call.
    OperationState loweredOp(loc, *intrinsicName);
    loweredOp.addTypes(outType);
    loweredOp.addOperands(
        {sourceA, sourceB, adaptor.getDestC(), packedScaleA, packedScaleB});
    loweredOp.addAttributes(attrs);

    Operation *lowered = rewriter.create(loweredOp);
    rewriter.replaceOp(op, lowered->getResults());

    return success();
  }
};
1938
1939struct TransposeLoadOpLowering
1940 : public ConvertOpToLLVMPattern<TransposeLoadOp> {
1941 TransposeLoadOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1942 : ConvertOpToLLVMPattern<TransposeLoadOp>(converter), chipset(chipset) {}
1943
1944 Chipset chipset;
1945
1946 LogicalResult
1947 matchAndRewrite(TransposeLoadOp op, TransposeLoadOpAdaptor adaptor,
1948 ConversionPatternRewriter &rewriter) const override {
1949 if (chipset != kGfx950)
1950 return op.emitOpError("Non-gfx950 chipset not supported");
1951
1952 Location loc = op.getLoc();
1953 auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
1954
1955 // Elements in subbyte memrefs are stored non-contiguously,
1956 // reject if source is sub-byte memref. Use emulated memrefs instead.
1957 size_t srcElementSize =
1958 srcMemRefType.getElementType().getIntOrFloatBitWidth();
1959 if (srcElementSize < 8)
1960 return op.emitOpError("Expect source memref to have at least 8 bits "
1961 "element size, got ")
1962 << srcElementSize;
1963
1964 auto resultType = cast<VectorType>(op.getResult().getType());
1965 Value srcPtr =
1966 getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
1967 (adaptor.getSrcIndices()));
1968
1969 size_t numElements = resultType.getNumElements();
1970 size_t elementTypeSize =
1971 resultType.getElementType().getIntOrFloatBitWidth();
1972
1973 // ROCDL transpose load intrinsics return vectors of 32-bit integers, if
1974 // the element size is smaller than 16 bits.
1975 Type rocdlResultType = VectorType::get((numElements * elementTypeSize) / 32,
1976 rewriter.getIntegerType(32));
1977 Type llvmResultType = typeConverter->convertType(resultType);
1978
1979 switch (elementTypeSize) {
1980 case 4: {
1981 assert(numElements == 16);
1982 auto rocdlOp = ROCDL::ds_read_tr4_b64::create(rewriter, loc,
1983 rocdlResultType, srcPtr);
1984 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
1985 break;
1986 }
1987 case 6: {
1988 assert(numElements == 16);
1989 auto rocdlOp = ROCDL::ds_read_tr6_b96::create(rewriter, loc,
1990 rocdlResultType, srcPtr);
1991 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
1992 break;
1993 }
1994 case 8: {
1995 assert(numElements == 8);
1996 auto rocdlOp = ROCDL::ds_read_tr8_b64::create(rewriter, loc,
1997 rocdlResultType, srcPtr);
1998 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
1999 break;
2000 }
2001 case 16: {
2002 assert(numElements == 4);
2003 rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr16_b64>(op, llvmResultType,
2004 srcPtr);
2005 break;
2006 }
2007 default:
2008 return op.emitOpError("Unsupported element size for transpose load");
2009 }
2010 return success();
2011 }
2012};
2013
/// Lowers `amdgpu.gather_to_lds` to the ROCDL load-to-LDS operations.
struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
  GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset.majorVersion < 9 || chipset.majorVersion > 10)
      return op.emitOpError("pre-gfx9 and post-gfx10 not supported");

    Location loc = op.getLoc();

    auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
    auto dstMemRefType = cast<MemRefType>(op.getDst().getType());

    // TODO: instead of only transferring one element per thread, we could
    // augment it to transfer multiple elements per thread by issuing multiple
    // `global_load_lds` instructions.
    // Width, in bytes, of the per-thread transfer (scalar or vector).
    Type transferType = op.getTransferType();
    int loadWidth = [&]() -> int {
      if (auto transferVectorType = dyn_cast<VectorType>(transferType)) {
        return (transferVectorType.getNumElements() *
                transferVectorType.getElementTypeBitWidth()) /
               8;
      }
      return transferType.getIntOrFloatBitWidth() / 8;
    }();

    // Currently only 1, 2, 4, 12 and 16 byte loads are supported.
    if (!llvm::is_contained({1, 2, 4, 12, 16}, loadWidth))
      return op.emitOpError("chipset unsupported element size");

    // The wide (12/16-byte) forms only exist on gfx950.
    if (chipset != kGfx950 && llvm::is_contained({12, 16}, loadWidth))
      return op.emitOpError("Gather to LDS instructions with 12-byte and "
                            "16-byte load widths are only supported on gfx950");

    Value srcPtr =
        getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
                             (adaptor.getSrcIndices()));
    Value dstPtr =
        getStridedElementPtr(rewriter, loc, dstMemRefType, adaptor.getDst(),
                             (adaptor.getDstIndices()));

    // Both variants take the byte width plus zero offset/aux immediates; the
    // trailing ArrayAttr{} arguments are left empty — presumably alias/access
    // metadata; confirm against the ROCDL op definitions.
    if (op.getAsync()) {
      rewriter.replaceOpWithNewOp<ROCDL::LoadAsyncToLDSOp>(
          op, srcPtr, dstPtr, rewriter.getI32IntegerAttr(loadWidth),
          /*offset=*/rewriter.getI32IntegerAttr(0),
          /*aux=*/rewriter.getI32IntegerAttr(0), ArrayAttr{}, ArrayAttr{},
          ArrayAttr{});
    } else {
      rewriter.replaceOpWithNewOp<ROCDL::LoadToLDSOp>(
          op, srcPtr, dstPtr, rewriter.getI32IntegerAttr(loadWidth),
          /*offset=*/rewriter.getI32IntegerAttr(0),
          /*aux=*/rewriter.getI32IntegerAttr(0), ArrayAttr{}, ArrayAttr{},
          ArrayAttr{});
    }

    return success();
  }
};
2076
namespace {
// Pattern declarations for the packed 8-bit and scaled floating-point
// conversion ops. The matchAndRewrite bodies are defined out-of-line below.

/// Lowering for `amdgpu.ext_packed_fp8`.
struct ExtPackedFp8OpLowering final
    : public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
  ExtPackedFp8OpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::ExtPackedFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Lowering for `amdgpu.scaled_ext_packed_matrix`.
struct ScaledExtPackedMatrixOpLowering final
    : public ConvertOpToLLVMPattern<ScaledExtPackedMatrixOp> {
  ScaledExtPackedMatrixOpLowering(const LLVMTypeConverter &converter,
                                  Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedMatrixOp>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(ScaledExtPackedMatrixOp op,
                  ScaledExtPackedMatrixOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Lowering for `amdgpu.packed_trunc_2xfp8`.
struct PackedTrunc2xFp8OpLowering final
    : public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> {
  PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter,
                             Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::PackedTrunc2xFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Lowering for `amdgpu.packed_stoch_round_fp8`.
struct PackedStochRoundFp8OpLowering final
    : public ConvertOpToLLVMPattern<PackedStochRoundFp8Op> {
  PackedStochRoundFp8OpLowering(const LLVMTypeConverter &converter,
                                Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::PackedStochRoundFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PackedStochRoundFp8Op op,
                  PackedStochRoundFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Lowering for `amdgpu.scaled_ext_packed`.
struct ScaledExtPackedOpLowering final
    : public ConvertOpToLLVMPattern<ScaledExtPackedOp> {
  ScaledExtPackedOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedOp>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Lowering for `amdgpu.packed_scaled_trunc`.
struct PackedScaledTruncOpLowering final
    : public ConvertOpToLLVMPattern<PackedScaledTruncOp> {
  PackedScaledTruncOpLowering(const LLVMTypeConverter &converter,
                              Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::PackedScaledTruncOp>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

} // end namespace
2157
/// Lowers `amdgpu.ext_packed_fp8` to the ROCDL fp8/bf8 conversion intrinsics
/// (available on gfx942 and OCP-fp8-capable chipsets).
LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
    ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type v4i8 =
      getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type()));
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
  Type f32 = getTypeConverter()->convertType(op.getResult().getType());

  Value source = adaptor.getSource();
  auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
  auto resultVecType = dyn_cast<VectorType>(op.getResult().getType());
  Type sourceElemType = getElementTypeOrSelf(op.getSource());
  // Extend to a v4i8
  if (!sourceVecType || sourceVecType.getNumElements() < 4) {
    // Widen a scalar or short-vector source to vector<4xi8>: insert the
    // known elements and leave the remaining lanes undefined.
    Value longVec = LLVM::UndefOp::create(rewriter, loc, v4i8);
    if (!sourceVecType) {
      longVec = LLVM::InsertElementOp::create(
          rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0));
    } else {
      for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
        Value idx = createI32Constant(rewriter, loc, i);
        Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
        longVec =
            LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
      }
    }
    source = longVec;
  }
  // The conversion intrinsics take the four packed bytes as one i32.
  Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
  // Vector results use the packed (CvtPk*) intrinsics; scalar results use the
  // single-element (Cvt*) intrinsics. Both receive the op's `index` operand.
  // NOTE(review): if the source element type is neither the expected bf8 nor
  // fp8 for this chipset, no replacement op is created even though success()
  // is returned — presumably that case is ruled out by the op verifier;
  // confirm.
  if (resultVecType) {
    if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
      rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source,
                                                        op.getIndex());
    } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
      rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Fp8Op>(op, f32, i32Source,
                                                        op.getIndex());
    }
  } else {
    if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
      rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
                                                      op.getIndex());
    } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
      rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
                                                      op.getIndex());
    }
  }
  return success();
}
2211
int32_t getScaleSel(int32_t blockSize, unsigned bitWidth, int32_t scaleWaveHalf,
                    int32_t firstScaleByte) {
  // When lowering amdgpu.scaled_ext_packed_matrix to rocdl.cvt.scale.pk*.f*.f*
  // operations, the attributes blockSize, sourceType, scaleWaveHalf, and
  // firstScaleByte are folded into the single scaleSel immediate. (Note:
  // scaleWaveHalf isn't a high-level attribute but is derived from
  // firstScaleLane.)
  assert(blockSize == 16 || blockSize == 32);
  assert(bitWidth == 4u || bitWidth == 6u || bitWidth == 8u);
  assert(scaleWaveHalf == 0 || scaleWaveHalf == 1);

  const int32_t blockBit = (blockSize == 16) ? 1 : 0;

  if (bitWidth != 8) {
    // fp4/fp6 sources: bit 0 selects 16-element blocks, bit 1 is set when
    // the scale starts at byte 2, and bit 2 picks the wave half.
    assert(firstScaleByte >= 0 && firstScaleByte <= 2);
    const int32_t byteBit = (firstScaleByte == 2) ? (1 << 1) : 0;
    return (scaleWaveHalf << 2) | byteBit | blockBit;
  }

  // fp8/bf8 sources: bit 0 selects 16-element blocks, bits 1-2 carry the
  // full two-bit firstScaleByte, and bit 3 picks the wave half.
  assert(firstScaleByte >= 0 && firstScaleByte <= 3);
  const int32_t sel = (scaleWaveHalf << 3) | (firstScaleByte << 1) | blockBit;
  // Reject the encodings the hardware does not define.
  assert(sel != 0b0011 && sel != 0b0101 && sel != 0b0111 && sel != 0b1000 &&
         sel != 0b1001 && sel != 0b1011 && sel != 0b1111);
  return sel;
}
2246
/// Returns the name of the ROCDL packed scaled-conversion intrinsic for the
/// given (source, destination) element-type pair, or std::nullopt when the
/// destination is not one of f16/bf16/f32. fp4/fp8/bf8 sources map to the
/// 8-element (Pk8) variants; fp6/bf6 sources map to the 16-element (Pk16)
/// variants. Any other source element type is a programming error.
static std::optional<StringRef>
scaledExtPacked816ToIntrinsic(Type srcElemType, Type destElemType) {
  // Short aliases for the supported small-float source element types.
  using fp4 = Float4E2M1FNType;
  using fp8 = Float8E4M3FNType;
  using bf8 = Float8E5M2Type;
  using fp6 = Float6E2M3FNType;
  using bf6 = Float6E3M2FNType;
  if (isa<fp4>(srcElemType)) {
    if (destElemType.isF16())
      return ROCDL::CvtPkScalePk8F16Fp4Op::getOperationName();
    if (destElemType.isBF16())
      return ROCDL::CvtPkScalePk8Bf16Fp4Op::getOperationName();
    if (destElemType.isF32())
      return ROCDL::CvtPkScalePk8F32Fp4Op::getOperationName();
    return std::nullopt;
  }
  if (isa<fp8>(srcElemType)) {
    if (destElemType.isF16())
      return ROCDL::CvtPkScalePk8F16Fp8Op::getOperationName();
    if (destElemType.isBF16())
      return ROCDL::CvtPkScalePk8Bf16Fp8Op::getOperationName();
    if (destElemType.isF32())
      return ROCDL::CvtPkScalePk8F32Fp8Op::getOperationName();
    return std::nullopt;
  }
  if (isa<bf8>(srcElemType)) {
    if (destElemType.isF16())
      return ROCDL::CvtPkScalePk8F16Bf8Op::getOperationName();
    if (destElemType.isBF16())
      return ROCDL::CvtPkScalePk8Bf16Bf8Op::getOperationName();
    if (destElemType.isF32())
      return ROCDL::CvtPkScalePk8F32Bf8Op::getOperationName();
    return std::nullopt;
  }
  if (isa<fp6>(srcElemType)) {
    if (destElemType.isF16())
      return ROCDL::CvtPkScalePk16F16Fp6Op::getOperationName();
    if (destElemType.isBF16())
      return ROCDL::CvtPkScalePk16Bf16Fp6Op::getOperationName();
    if (destElemType.isF32())
      return ROCDL::CvtPkScalePk16F32Fp6Op::getOperationName();
    return std::nullopt;
  }
  if (isa<bf6>(srcElemType)) {
    if (destElemType.isF16())
      return ROCDL::CvtPkScalePk16F16Bf6Op::getOperationName();
    if (destElemType.isBF16())
      return ROCDL::CvtPkScalePk16Bf16Bf6Op::getOperationName();
    if (destElemType.isF32())
      return ROCDL::CvtPkScalePk16F32Bf6Op::getOperationName();
    return std::nullopt;
  }
  llvm_unreachable("invalid combination of element types for packed conversion "
                   "instructions");
}
2302
/// Lowers `amdgpu.scaled_ext_packed_matrix` to the gfx1250
/// `rocdl.cvt.scale.pk*` conversion intrinsics.
LogicalResult ScaledExtPackedMatrixOpLowering::matchAndRewrite(
    ScaledExtPackedMatrixOp op, ScaledExtPackedMatrixOpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // Short aliases for the supported small-float source element types.
  using fp4 = Float4E2M1FNType;
  using fp8 = Float8E4M3FNType;
  using bf8 = Float8E5M2Type;
  using fp6 = Float6E2M3FNType;
  using bf6 = Float6E3M2FNType;
  Location loc = op.getLoc();
  if (chipset != kGfx1250) {
    return rewriter.notifyMatchFailure(
        loc,
        "Scaled fp packed conversion instructions are not available on target "
        "architecture and their emulation is not implemented");
  }
  // Convert user-facing firstScaleLane (0 or 16) to the half of the wave that
  // is being selected.
  int32_t scaleWaveHalf = op.getFirstScaleLane() / 16;
  int32_t firstScaleByte = op.getFirstScaleByte();
  int32_t blockSize = op.getBlockSize();
  auto sourceType = cast<VectorType>(op.getSource().getType());
  auto srcElemType = cast<FloatType>(sourceType.getElementType());
  unsigned bitWidth = srcElemType.getWidth();

  auto targetType = cast<VectorType>(op.getResult().getType());
  auto destElemType = cast<FloatType>(targetType.getElementType());

  IntegerType i32 = rewriter.getI32Type();
  Value source = adaptor.getSource();
  Type llvmResultType = typeConverter->convertType(op.getResult().getType());
  // The intrinsics take the source reinterpreted as i32 words: one word for
  // fp4, two for fp8/bf8, three for fp6/bf6.
  Type packedType = nullptr;
  if (isa<fp4>(srcElemType)) {
    packedType = i32;
    packedType = getTypeConverter()->convertType(packedType);
  } else if (isa<fp8, bf8>(srcElemType)) {
    packedType = VectorType::get(2, i32);
    packedType = getTypeConverter()->convertType(packedType);
  } else if (isa<fp6, bf6>(srcElemType)) {
    packedType = VectorType::get(3, i32);
    packedType = getTypeConverter()->convertType(packedType);
  } else {
    llvm_unreachable("invalid element type for packed scaled ext");
  }

  if (!packedType || !llvmResultType) {
    return rewriter.notifyMatchFailure(op, "type conversion failed");
  }

  std::optional<StringRef> maybeIntrinsic =
      scaledExtPacked816ToIntrinsic(srcElemType, destElemType);
  if (!maybeIntrinsic.has_value())
    return op.emitOpError(
        "no intrinsic matching packed scaled conversion on the given chipset");

  // Fold blockSize/bitWidth/scaleWaveHalf/firstScaleByte into the single
  // scaleSel immediate expected by the intrinsic.
  int32_t scaleSel =
      getScaleSel(blockSize, bitWidth, scaleWaveHalf, firstScaleByte);
  Value castedScale =
      LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getScale());
  Value castedSource =
      LLVM::BitcastOp::create(rewriter, loc, packedType, source);

  OperationState loweredOp(loc, *maybeIntrinsic);
  loweredOp.addTypes({llvmResultType});
  loweredOp.addOperands({castedSource, castedScale});

  SmallVector<NamedAttribute, 1> attrs;
  attrs.push_back(
      NamedAttribute("scaleSel", rewriter.getI32IntegerAttr(scaleSel)));

  loweredOp.addAttributes(attrs);
  Operation *lowered = rewriter.create(loweredOp);
  rewriter.replaceOp(op, lowered);

  return success();
}
2378
2379LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
2380 ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
2381 ConversionPatternRewriter &rewriter) const {
2382 Location loc = op.getLoc();
2383 if (chipset != kGfx950)
2384 return rewriter.notifyMatchFailure(
2385 loc, "Scaled fp conversion instructions are not available on target "
2386 "architecture and their emulation is not implemented");
2387 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2388
2389 Value source = adaptor.getSource();
2390 Value scale = adaptor.getScale();
2391
2392 VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
2393 Type sourceElemType = sourceVecType.getElementType();
2394 VectorType destVecType = cast<VectorType>(op.getResult().getType());
2395 Type destElemType = destVecType.getElementType();
2396
2397 VectorType packedVecType;
2398 if (isa<Float8E5M2Type, Float8E4M3FNType>(sourceElemType)) {
2399 VectorType v4i8 = VectorType::get(4, rewriter.getI8Type());
2400 packedVecType = cast<VectorType>(getTypeConverter()->convertType(v4i8));
2401 } else if (isa<Float4E2M1FNType>(sourceElemType)) {
2402 VectorType v8i4 = VectorType::get(8, rewriter.getI4Type());
2403 packedVecType = cast<VectorType>(getTypeConverter()->convertType(v8i4));
2404 } else {
2405 llvm_unreachable("invalid element type for scaled ext");
2406 }
2407
2408 // Extend to a packedVectorType
2409 if (sourceVecType.getNumElements() < packedVecType.getNumElements()) {
2410 Value longVec = LLVM::ZeroOp::create(rewriter, loc, packedVecType);
2411 if (!sourceVecType) {
2412 longVec = LLVM::InsertElementOp::create(
2413 rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0));
2414 } else {
2415 for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
2416 Value idx = createI32Constant(rewriter, loc, i);
2417 Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
2418 longVec =
2419 LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
2420 }
2421 }
2422 source = longVec;
2423 }
2424 Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
2425
2426 if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF32())
2427 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Bf8Op>(
2428 op, destVecType, i32Source, scale, op.getIndex());
2429 else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF16())
2430 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Bf8Op>(
2431 op, destVecType, i32Source, scale, op.getIndex());
2432 else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isBF16())
2433 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Bf8Op>(
2434 op, destVecType, i32Source, scale, op.getIndex());
2435 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF32())
2436 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp8Op>(
2437 op, destVecType, i32Source, scale, op.getIndex());
2438 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF16())
2439 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp8Op>(
2440 op, destVecType, i32Source, scale, op.getIndex());
2441 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isBF16())
2442 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp8Op>(
2443 op, destVecType, i32Source, scale, op.getIndex());
2444 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF32())
2445 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp4Op>(
2446 op, destVecType, i32Source, scale, op.getIndex());
2447 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF16())
2448 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp4Op>(
2449 op, destVecType, i32Source, scale, op.getIndex());
2450 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isBF16())
2451 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp4Op>(
2452 op, destVecType, i32Source, scale, op.getIndex());
2453 else
2454 return failure();
2455
2456 return success();
2457}
2458
LogicalResult PackedScaledTruncOpLowering::matchAndRewrite(
    PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  // Scaled trunc instructions only exist on gfx950.
  if (chipset != kGfx950)
    return rewriter.notifyMatchFailure(
        loc, "Scaled fp conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type v2i16 = getTypeConverter()->convertType(
      VectorType::get(2, rewriter.getI16Type()));
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Type resultType = op.getResult().getType();
  Type resultElemType = getElementTypeOrSelf(resultType);
  VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
  Type sourceElemType = sourceVecType.getElementType();

  // The intrinsics write their packed result into an i32 (fp4) or a v2i16
  // (fp8/bf8) register.
  Type intResultType = isa<Float4E2M1FNType>(resultElemType) ? i32 : v2i16;

  Value source = adaptor.getSource();
  Value scale = adaptor.getScale();
  // Register being updated: bitcast of the caller-provided value, or zero.
  Value existing = adaptor.getExisting();
  if (existing)
    existing = LLVM::BitcastOp::create(rewriter, loc, intResultType, existing);
  else
    existing = LLVM::ZeroOp::create(rewriter, loc, intResultType);

  // Widen a single-element source to the two lanes the packed intrinsics
  // expect; the second lane is zero.
  if (sourceVecType.getNumElements() < 2) {
    Value c0 = createI32Constant(rewriter, loc, 0);
    Value elem0 = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
    VectorType v2 = VectorType::get(2, sourceElemType);
    source = LLVM::ZeroOp::create(rewriter, loc, v2);
    source = LLVM::InsertElementOp::create(rewriter, loc, source, elem0, c0);
  }

  // The f32 variants take the two lanes as separate scalar operands, so
  // extract them up front.
  Value sourceA, sourceB;
  if (sourceElemType.isF32()) {
    Value c0 = createI32Constant(rewriter, loc, 0);
    Value c1 = createI32Constant(rewriter, loc, 1);
    sourceA = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
    sourceB = LLVM::ExtractElementOp::create(rewriter, loc, source, c1);
  }

  // Dispatch on the (source, result) element-type pair to the matching ROCDL
  // scaled-trunc intrinsic.
  Value result;
  if (sourceElemType.isF32() && isa<Float8E5M2Type>(resultElemType))
    result = ROCDL::CvtScaleF32PkBf8F32Op::create(rewriter, loc, intResultType,
                                                  existing, sourceA, sourceB,
                                                  scale, op.getIndex());
  else if (sourceElemType.isF16() && isa<Float8E5M2Type>(resultElemType))
    result = ROCDL::CvtScaleF32PkBf8F16Op::create(
        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  else if (sourceElemType.isBF16() && isa<Float8E5M2Type>(resultElemType))
    result = ROCDL::CvtScaleF32PkBf8Bf16Op::create(
        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  else if (sourceElemType.isF32() && isa<Float8E4M3FNType>(resultElemType))
    result = ROCDL::CvtScaleF32PkFp8F32Op::create(rewriter, loc, intResultType,
                                                  existing, sourceA, sourceB,
                                                  scale, op.getIndex());
  else if (sourceElemType.isF16() && isa<Float8E4M3FNType>(resultElemType))
    result = ROCDL::CvtScaleF32PkFp8F16Op::create(
        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  else if (sourceElemType.isBF16() && isa<Float8E4M3FNType>(resultElemType))
    result = ROCDL::CvtScaleF32PkFp8Bf16Op::create(
        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  else if (sourceElemType.isF32() && isa<Float4E2M1FNType>(resultElemType))
    result = ROCDL::CvtScaleF32PkFp4F32Op::create(rewriter, loc, intResultType,
                                                  existing, sourceA, sourceB,
                                                  scale, op.getIndex());
  else if (sourceElemType.isF16() && isa<Float4E2M1FNType>(resultElemType))
    result = ROCDL::CvtScaleF32PkFp4F16Op::create(
        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  else if (sourceElemType.isBF16() && isa<Float4E2M1FNType>(resultElemType))
    result = ROCDL::CvtScaleF32PkFp4Bf16Op::create(
        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  else
    return failure();

  // Bitcast the packed integer register back to the op's declared result type.
  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
  return success();
}
2540
LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
    PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  // Packed fp8 trunc is available on gfx942 and on OCP-fp8-capable chipsets.
  if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Type resultType = op.getResult().getType();
  Type resultElemType = getElementTypeOrSelf(resultType);

  Value sourceA = adaptor.getSourceA();
  Value sourceB = adaptor.getSourceB();
  // The intrinsic always takes two sources; an absent B becomes undef.
  if (!sourceB)
    sourceB = LLVM::UndefOp::create(rewriter, loc, sourceA.getType());
  // The i32 word being updated: the caller's value (bitcast) or undef.
  Value existing = adaptor.getExisting();
  if (existing)
    existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
  else
    existing = LLVM::UndefOp::create(rewriter, loc, i32);

  Value result;
  if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
    result = ROCDL::CvtPkBf8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
                                          existing, op.getWordIndex());
  else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
    result = ROCDL::CvtPkFp8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
                                          existing, op.getWordIndex());

  // NOTE(review): if neither predicate above matched, `result` would be a
  // null Value here — presumably the op verifier guarantees one of the two
  // fp8 element types; confirm.
  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
  return success();
}
2576
2577LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
2578 PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
2579 ConversionPatternRewriter &rewriter) const {
2580 Location loc = op.getLoc();
2581 if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
2582 return rewriter.notifyMatchFailure(
2583 loc, "Fp8 conversion instructions are not available on target "
2584 "architecture and their emulation is not implemented");
2585 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2586
2587 Type resultType = op.getResult().getType();
2588 Type resultElemType = getElementTypeOrSelf(resultType);
2589
2590 Value source = adaptor.getSource();
2591 Value stoch = adaptor.getStochiasticParam();
2592 Value existing = adaptor.getExisting();
2593 if (existing)
2594 existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
2595 else
2596 existing = LLVM::UndefOp::create(rewriter, loc, i32);
2597
2598 Value result;
2599 if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
2600 result = ROCDL::CvtSrBf8F32Op::create(rewriter, loc, i32, source, stoch,
2601 existing, op.getStoreIndex());
2602 else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
2603 result = ROCDL::CvtSrFp8F32Op::create(rewriter, loc, i32, source, stoch,
2604 existing, op.getStoreIndex());
2605
2606 result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
2607 op, getTypeConverter()->convertType(resultType), result);
2608 return success();
2609}
2610
// Implement the AMDGPU_DPPLowering class that will convert the amdgpu.dpp
// operation into the corresponding ROCDL instructions.
struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
  AMDGPUDPPLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    // Convert the source operand to the corresponding LLVM type
    Location loc = DppOp.getLoc();
    Value src = adaptor.getSrc();
    Value old = adaptor.getOld();
    Type srcType = src.getType();
    Type oldType = old.getType();
    // Choose the 32- or 64-bit register type the dpp intrinsic operates on;
    // sub-32-bit operands are widened to i32 below.
    Type llvmType = nullptr;
    if (srcType.getIntOrFloatBitWidth() < 32) {
      llvmType = rewriter.getI32Type();
    } else if (isa<FloatType>(srcType)) {
      llvmType = (srcType.getIntOrFloatBitWidth() == 32)
                     ? rewriter.getF32Type()
                     : rewriter.getF64Type();
    } else if (isa<IntegerType>(srcType)) {
      llvmType = (srcType.getIntOrFloatBitWidth() == 32)
                     ? rewriter.getI32Type()
                     : rewriter.getI64Type();
    }
    // Integer type with the same bit width as the source element.
    auto llvmSrcIntType = typeConverter->convertType(
        rewriter.getIntegerType(srcType.getIntOrFloatBitWidth()));

    // If the source type is less of 32, use bitcast to convert it to i32.
    auto convertOperand = [&](Value operand, Type operandType) {
      if (operandType.getIntOrFloatBitWidth() <= 16) {
        // Reinterpret small floats as same-width integers first.
        if (llvm::isa<FloatType>(operandType)) {
          operand =
              LLVM::BitcastOp::create(rewriter, loc, llvmSrcIntType, operand);
        }
        // Insert into lane 0 of an undef 32-bit-wide vector, then bitcast
        // the whole vector to the i32 register type.
        auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
            32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
        Value undefVec = LLVM::UndefOp::create(rewriter, loc, llvmVecType);
        operand =
            LLVM::InsertElementOp::create(rewriter, loc, undefVec, operand,
                                          createI32Constant(rewriter, loc, 0));
        operand = LLVM::BitcastOp::create(rewriter, loc, llvmType, operand);
      }
      return operand;
    };

    src = convertOperand(src, srcType);
    old = convertOperand(old, oldType);

    // This is taken from the following file llvm/lib/Target/AMDGPU/SIDefines.h
    enum DppCtrl : unsigned {
      ROW_SHL0 = 0x100,
      ROW_SHR0 = 0x110,
      ROW_ROR0 = 0x120,
      WAVE_SHL1 = 0x130,
      WAVE_ROL1 = 0x134,
      WAVE_SHR1 = 0x138,
      WAVE_ROR1 = 0x13C,
      ROW_MIRROR = 0x140,
      ROW_HALF_MIRROR = 0x141,
      BCAST15 = 0x142,
      BCAST31 = 0x143,
    };

    auto kind = DppOp.getKind();
    auto permArgument = DppOp.getPermArgument();
    uint32_t DppCtrl = 0;

    // Encode the permutation kind (plus its argument, when present) into the
    // hardware dpp_ctrl immediate.
    switch (kind) {

    case DPPPerm::quad_perm: {
      // Four 2-bit lane selectors, packed from bit 0 upward.
      auto quadPermAttr = cast<ArrayAttr>(*permArgument);
      int32_t i = 0;
      for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
        uint32_t num = elem.getInt();
        DppCtrl |= num << (i * 2);
        i++;
      }
      break;
    }
    case DPPPerm::row_shl: {
      auto intAttr = cast<IntegerAttr>(*permArgument);
      DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
      break;
    }
    case DPPPerm::row_shr: {
      auto intAttr = cast<IntegerAttr>(*permArgument);
      DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
      break;
    }
    case DPPPerm::row_ror: {
      auto intAttr = cast<IntegerAttr>(*permArgument);
      DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
      break;
    }
    case DPPPerm::wave_shl:
      DppCtrl = DppCtrl::WAVE_SHL1;
      break;
    case DPPPerm::wave_shr:
      DppCtrl = DppCtrl::WAVE_SHR1;
      break;
    case DPPPerm::wave_rol:
      DppCtrl = DppCtrl::WAVE_ROL1;
      break;
    case DPPPerm::wave_ror:
      DppCtrl = DppCtrl::WAVE_ROR1;
      break;
    case DPPPerm::row_mirror:
      DppCtrl = DppCtrl::ROW_MIRROR;
      break;
    case DPPPerm::row_half_mirror:
      DppCtrl = DppCtrl::ROW_HALF_MIRROR;
      break;
    case DPPPerm::row_bcast_15:
      DppCtrl = DppCtrl::BCAST15;
      break;
    case DPPPerm::row_bcast_31:
      DppCtrl = DppCtrl::BCAST31;
      break;
    }

    // Check for row_mask, bank_mask, bound_ctrl if they exist and create
    // constants
    auto rowMask = DppOp->getAttrOfType<IntegerAttr>("row_mask").getInt();
    auto bankMask = DppOp->getAttrOfType<IntegerAttr>("bank_mask").getInt();
    bool boundCtrl = DppOp->getAttrOfType<BoolAttr>("bound_ctrl").getValue();

    // create a ROCDL_DPPMovOp instruction with the appropriate attributes
    auto dppMovOp =
        ROCDL::DPPUpdateOp::create(rewriter, loc, llvmType, old, src, DppCtrl,
                                   rowMask, bankMask, boundCtrl);

    Value result = dppMovOp.getRes();
    // Narrow sub-32-bit results back to the original element type.
    if (srcType.getIntOrFloatBitWidth() < 32) {
      result = LLVM::TruncOp::create(rewriter, loc, llvmSrcIntType, result);
      if (!llvm::isa<IntegerType>(srcType)) {
        result = LLVM::BitcastOp::create(rewriter, loc, srcType, result);
      }
    }

    // We are replacing the AMDGPU_DPPOp instruction with the new
    // ROCDL_DPPMovOp instruction
    rewriter.replaceOp(DppOp, ValueRange(result));
    return success();
  }
};
2761
2762struct AMDGPUSwizzleBitModeLowering
2763 : public ConvertOpToLLVMPattern<SwizzleBitModeOp> {
2765
2766 LogicalResult
2767 matchAndRewrite(SwizzleBitModeOp op, OpAdaptor adaptor,
2768 ConversionPatternRewriter &rewriter) const override {
2769 Location loc = op.getLoc();
2770 Type i32 = rewriter.getI32Type();
2771 Value src = adaptor.getSrc();
2772 SmallVector<Value> decomposed;
2773 if (failed(LLVM::decomposeValue(rewriter, loc, src, i32, decomposed)))
2774 return rewriter.notifyMatchFailure(op,
2775 "failed to decompose value to i32");
2776 unsigned andMask = op.getAndMask();
2777 unsigned orMask = op.getOrMask();
2778 unsigned xorMask = op.getXorMask();
2779
2780 // bit 15 is 0 for the BitMode swizzle.
2781 // https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/
2782 unsigned mask = andMask | (orMask << 5) | (xorMask << 10);
2783 Value maskValue = createI32Constant(rewriter, loc, mask);
2784 SmallVector<Value> swizzled;
2785 for (Value v : decomposed) {
2786 Value res =
2787 ROCDL::DsSwizzleOp::create(rewriter, loc, v.getType(), v, maskValue);
2788 swizzled.emplace_back(res);
2789 }
2790
2791 Value result = LLVM::composeValue(rewriter, loc, swizzled, src.getType());
2792 rewriter.replaceOp(op, result);
2793 return success();
2794 }
2795};
2796
// Lowers amdgpu.permlane_swap to rocdl.permlane(16|32).swap, one intrinsic
// per decomposed i32 piece of the source value.
struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneSwapOp> {

  AMDGPUPermlaneLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<PermlaneSwapOp>(converter), chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PermlaneSwapOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset < kGfx950)
      return op->emitOpError("permlane_swap is only supported on gfx950+");

    Location loc = op.getLoc();
    Type i32 = rewriter.getI32Type();
    Value src = adaptor.getSrc();
    unsigned rowLength = op.getRowLength();
    bool fi = op.getFetchInactive();
    bool boundctrl = op.getBoundCtrl();

    // The permlane intrinsics operate on 32-bit registers; split wider
    // values into i32 pieces.
    SmallVector<Value> decomposed;
    if (failed(LLVM::decomposeValue(rewriter, loc, src, i32, decomposed)))
      return rewriter.notifyMatchFailure(op,
                                         "failed to decompose value to i32");

    SmallVector<Value> permuted;
    for (Value v : decomposed) {
      Value res;
      // The intrinsic returns a literal struct of the two swapped registers.
      Type i32pair = LLVM::LLVMStructType::getLiteral(
          rewriter.getContext(), {v.getType(), v.getType()});

      if (rowLength == 16)
        res = ROCDL::Permlane16SwapOp::create(rewriter, loc, i32pair, v, v, fi,
                                              boundctrl);
      else if (rowLength == 32)
        res = ROCDL::Permlane32SwapOp::create(rewriter, loc, i32pair, v, v, fi,
                                              boundctrl);
      else
        llvm_unreachable("unsupported row length");

      Value vdst0 = LLVM::ExtractValueOp::create(rewriter, loc, res, {0});
      Value vdst1 = LLVM::ExtractValueOp::create(rewriter, loc, res, {1});

      Value isEqual = LLVM::ICmpOp::create(rewriter, loc,
                                           LLVM::ICmpPredicate::eq, vdst0, v);

      // Per `permlane(16|32)` semantics: if the first extracted element equals
      // 'v', the result is the second element; otherwise it is the first.
      Value vdstNew =
          LLVM::SelectOp::create(rewriter, loc, isEqual, vdst1, vdst0);
      permuted.emplace_back(vdstNew);
    }

    // Reassemble the permuted i32 pieces into the original source type.
    Value result = LLVM::composeValue(rewriter, loc, permuted, src.getType());
    rewriter.replaceOp(op, result);
    return success();
  }
};
2855
2856//===----------------------------------------------------------------------===//
2857// In-LDS Barrier Operations
2858//===----------------------------------------------------------------------===//
2859
2860// Bit layout of ds_barrier_state (as i64):
2861// [63:32] init count (32 bits)
2862// [31:29] phase (3 bits)
2863// [28:0] pending count (29 bits)
// Width of the pending-count field in the low bits of the i64 state.
constexpr int32_t kDsBarrierPendingCountBitWidth = 29;
// The phase bits sit immediately above the pending count.
constexpr int32_t kDsBarrierPhasePos = kDsBarrierPendingCountBitWidth;
// The init count occupies the high 32-bit word of the state.
constexpr int32_t kDsBarrierInitCountPos = 32;
// Mask selecting the low 29 pending-count bits.
constexpr int32_t kDsBarrierPendingCountMask =
    (1 << kDsBarrierPendingCountBitWidth) - 1;
2869
// Initializes a 64-bit barrier state word in LDS (see the bit layout above).
struct DsBarrierInitOpLowering
    : public ConvertOpToLLVMPattern<DsBarrierInitOp> {
  Chipset chipset;

  DsBarrierInitOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<DsBarrierInitOp>(converter), chipset(chipset) {}

  LogicalResult
  matchAndRewrite(DsBarrierInitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // In-LDS barriers are a gfx1250+ feature.
    if (chipset < kGfx1250)
      return op->emitOpError("only supported on gfx1250+");

    Location loc = op.getLoc();
    Type i64 = rewriter.getI64Type();

    // Address of the barrier state word, computed from memref base + indices.
    MemRefType memrefType = cast<MemRefType>(op.getBase().getType());
    Value ptr = getStridedElementPtr(rewriter, loc, memrefType,
                                     adaptor.getBase(), adaptor.getIndices());

    // Note: We give participants as the number of arrivals that have to occur
    // before the phase changes. Hardware changes the phase when updating the
    // pending count would underflow, so we subtract 1 to get the behavior we're
    // looking for.
    Value initCount =
        LLVM::SubOp::create(rewriter, loc, adaptor.getParticipants(),
                            createI32Constant(rewriter, loc, 1));

    // Just a bit of paranoia, but this also allows for configurable width if
    // that becomes a thing.
    Value countMask =
        createI32Constant(rewriter, loc, kDsBarrierPendingCountMask);
    Value maskedCount32 =
        LLVM::AndOp::create(rewriter, loc, initCount, countMask);
    Value maskedCount = LLVM::ZExtOp::create(rewriter, loc, i64, maskedCount32);

    // Initial state: init count in the high word, the same value as pending
    // count in the low bits, phase 0.
    Value initCountShifted = LLVM::ShlOp::create(
        rewriter, loc, maskedCount,
        createI64Constant(rewriter, loc, kDsBarrierInitCountPos));
    Value barrierState =
        LLVM::OrOp::create(rewriter, loc, initCountShifted, maskedCount);

    // Publish the state with a workgroup-scoped release store.
    LLVM::StoreOp::create(
        rewriter, loc, barrierState, ptr, /*alignment=*/8, /*isVolatile=*/false,
        /*isNonTemporal=*/false,
        /*isInvariantGroup=*/false, LLVM::AtomicOrdering::release,
        /*syncscope=*/"workgroup");

    rewriter.eraseOp(op);
    return success();
  }
};
2922
2923struct DsBarrierPollStateOpLowering
2924 : public ConvertOpToLLVMPattern<DsBarrierPollStateOp> {
2925 Chipset chipset;
2926
2927 DsBarrierPollStateOpLowering(const LLVMTypeConverter &converter,
2928 Chipset chipset)
2929 : ConvertOpToLLVMPattern<DsBarrierPollStateOp>(converter),
2930 chipset(chipset) {}
2931
2932 LogicalResult
2933 matchAndRewrite(DsBarrierPollStateOp op, OpAdaptor adaptor,
2934 ConversionPatternRewriter &rewriter) const override {
2935 if (chipset < kGfx1250)
2936 return op->emitOpError("only supported on gfx1250+");
2937
2938 Location loc = op.getLoc();
2939 Type i64 = rewriter.getI64Type();
2940
2941 MemRefType memrefType = cast<MemRefType>(op.getBase().getType());
2942 Value ptr = getStridedElementPtr(rewriter, loc, memrefType,
2943 adaptor.getBase(), adaptor.getIndices());
2944
2945 // Atomic load with workgroup scope and acquire ordering should be what
2946 // we're looking for.
2947 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
2948 op, i64, ptr, /*alignment=*/8, /*volatile_=*/false,
2949 /*nontemporal=*/false, /*invariant=*/false,
2950 /*invariantGroup=*/false, LLVM::AtomicOrdering::acquire,
2951 /*syncscope=*/"workgroup");
2952 return success();
2953 }
2954};
2955
2956struct DsAsyncBarrierArriveOpLowering
2957 : public ConvertOpToLLVMPattern<DsAsyncBarrierArriveOp> {
2958 Chipset chipset;
2959
2960 DsAsyncBarrierArriveOpLowering(const LLVMTypeConverter &converter,
2961 Chipset chipset)
2962 : ConvertOpToLLVMPattern<DsAsyncBarrierArriveOp>(converter),
2963 chipset(chipset) {}
2964
2965 LogicalResult
2966 matchAndRewrite(DsAsyncBarrierArriveOp op, OpAdaptor adaptor,
2967 ConversionPatternRewriter &rewriter) const override {
2968 if (chipset < kGfx1250)
2969 return op->emitOpError("only supported on gfx1250+");
2970
2971 Location loc = op.getLoc();
2972
2973 MemRefType memrefType = cast<MemRefType>(op.getBase().getType());
2974 Value ptr = getStridedElementPtr(rewriter, loc, memrefType,
2975 adaptor.getBase(), adaptor.getIndices());
2976
2977 rewriter.replaceOpWithNewOp<ROCDL::DsAtomicAsyncBarrierArriveOp>(
2978 op, ptr, /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr,
2979 /*tbaa=*/nullptr);
2980 return success();
2981 }
2982};
2983
// Lowers the synchronous barrier-arrive to the returning ROCDL atomic op.
struct DsBarrierArriveOpLowering
    : public ConvertOpToLLVMPattern<DsBarrierArriveOp> {
  Chipset chipset;

  DsBarrierArriveOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<DsBarrierArriveOp>(converter), chipset(chipset) {
  }

  LogicalResult
  matchAndRewrite(DsBarrierArriveOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // This instruction only exists on gfx1250+.
    if (chipset < kGfx1250)
      return op->emitOpError("only supported on gfx1250+");

    Location loc = op.getLoc();
    Type i64 = rewriter.getI64Type();

    // Address of the 64-bit barrier state word in LDS.
    MemRefType memrefType = cast<MemRefType>(op.getBase().getType());
    Value ptr = getStridedElementPtr(rewriter, loc, memrefType,
                                     adaptor.getBase(), adaptor.getIndices());

    // Arrive with `count`; the Rtn variant also yields the i64 barrier state.
    rewriter.replaceOpWithNewOp<ROCDL::DsAtomicBarrierArriveRtnOp>(
        op, i64, ptr, adaptor.getCount(), /*alias_scopes=*/nullptr,
        /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
    return success();
  }
};
3011
3012struct DsBarrierStatePhaseOpLowering
3013 : public ConvertOpToLLVMPattern<DsBarrierStatePhaseOp> {
3015
3016 LogicalResult
3017 matchAndRewrite(DsBarrierStatePhaseOp op, OpAdaptor adaptor,
3018 ConversionPatternRewriter &rewriter) const override {
3019 Location loc = op.getLoc();
3020 Type i32 = rewriter.getI32Type();
3021
3022 Value state = adaptor.getState();
3023
3024 Value noInitCount = LLVM::TruncOp::create(rewriter, loc, i32, state);
3025 Value phase = LLVM::LShrOp::create(
3026 rewriter, loc, noInitCount,
3027 createI32Constant(rewriter, loc, kDsBarrierPhasePos));
3028
3029 rewriter.replaceOp(op, phase);
3030 return success();
3031 }
3032};
3033
3034struct DsBarrierStatePendingCountOpLowering
3035 : public ConvertOpToLLVMPattern<DsBarrierStatePendingCountOp> {
3037
3038 LogicalResult
3039 matchAndRewrite(DsBarrierStatePendingCountOp op, OpAdaptor adaptor,
3040 ConversionPatternRewriter &rewriter) const override {
3041 Location loc = op.getLoc();
3042 Type i32 = rewriter.getI32Type();
3043
3044 Value state = adaptor.getState();
3045
3046 Value noInitCount = LLVM::TruncOp::create(rewriter, loc, i32, state);
3047 Value pendingCount = LLVM::AndOp::create(
3048 rewriter, loc, noInitCount,
3049 createI32Constant(rewriter, loc,
3050 static_cast<uint32_t>(kDsBarrierPendingCountMask)));
3051
3052 rewriter.replaceOp(op, pendingCount);
3053 return success();
3054 }
3055};
3056
3057struct DsBarrierStateInitCountOpLowering
3058 : public ConvertOpToLLVMPattern<DsBarrierStateInitCountOp> {
3060
3061 LogicalResult
3062 matchAndRewrite(DsBarrierStateInitCountOp op, OpAdaptor adaptor,
3063 ConversionPatternRewriter &rewriter) const override {
3064 Location loc = op.getLoc();
3065 Type i32 = rewriter.getI32Type();
3066
3067 Value state = adaptor.getState();
3068
3069 Value initCountI64 = LLVM::LShrOp::create(
3070 rewriter, loc, state,
3071 createI64Constant(rewriter, loc, kDsBarrierInitCountPos));
3072 Value initCount = LLVM::TruncOp::create(rewriter, loc, i32, initCountI64);
3073
3074 rewriter.replaceOp(op, initCount);
3075 return success();
3076 }
3077};
3078
3079struct DsBarrierStatePhaseParityLowering
3080 : public ConvertOpToLLVMPattern<DsBarrierStatePhaseParity> {
3082
3083 LogicalResult
3084 matchAndRewrite(DsBarrierStatePhaseParity op, OpAdaptor adaptor,
3085 ConversionPatternRewriter &rewriter) const override {
3086 Location loc = op.getLoc();
3087 Type i1 = rewriter.getI1Type();
3088
3089 Value state = adaptor.getState();
3090
3091 Value noInitCount =
3092 LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), state);
3093 Value phase = LLVM::LShrOp::create(
3094 rewriter, loc, noInitCount,
3095 createI32Constant(rewriter, loc, kDsBarrierPhasePos));
3096 Value parity = LLVM::TruncOp::create(rewriter, loc, i1, phase);
3097
3098 rewriter.replaceOp(op, parity);
3099 return success();
3100 }
3101};
3102
3103//===----------------------------------------------------------------------===//
3104// Tensor Data Mover (TDM)
3105//===----------------------------------------------------------------------===//
3106
3107static Value setValueAtOffset(ConversionPatternRewriter &rewriter, Location loc,
3108 Value accumulator, Value value, int64_t shift) {
3109 shift = shift % 32;
3110 Value shiftAmount;
3111 if (shift != 0) {
3112 shiftAmount = createI32Constant(rewriter, loc, shift % 32);
3113 value = LLVM::ShlOp::create(rewriter, loc, value, shiftAmount);
3114 }
3115
3116 if (matchPattern(accumulator, mlir::m_Zero()))
3117 return value;
3118
3119 constexpr bool isDisjoint = true;
3120 return LLVM::OrOp::create(rewriter, loc, accumulator, value, isDisjoint);
3121}
3122
3123template <typename BaseOp>
3124struct AMDGPUMakeDmaBaseLowering : public ConvertOpToLLVMPattern<BaseOp> {
3125 using ConvertOpToLLVMPattern<BaseOp>::ConvertOpToLLVMPattern;
3126 using Adaptor = typename ConvertOpToLLVMPattern<BaseOp>::OpAdaptor;
3127
3128 AMDGPUMakeDmaBaseLowering(const LLVMTypeConverter &converter, Chipset chipset)
3129 : ConvertOpToLLVMPattern<BaseOp>(converter), chipset(chipset) {}
3130 Chipset chipset;
3131
3132 LogicalResult
3133 matchAndRewrite(BaseOp op, Adaptor adaptor,
3134 ConversionPatternRewriter &rewriter) const override {
3135 if (chipset < kGfx1250)
3136 return op->emitOpError("make_dma_base is only supported on gfx1250");
3137
3138 Location loc = op.getLoc();
3139
3140 constexpr int32_t constlen = 4;
3141 Value consts[constlen];
3142 for (int64_t i = 0; i < constlen; ++i)
3143 consts[i] = createI32Constant(rewriter, loc, i);
3144
3145 constexpr int32_t sgprslen = constlen;
3146 Value sgprs[sgprslen];
3147 for (int64_t i = 0; i < sgprslen; ++i) {
3148 sgprs[i] = consts[0];
3149 }
3150
3151 sgprs[0] = consts[1];
3152
3153 if constexpr (BaseOp::isGather()) {
3154 sgprs[0] = setValueAtOffset(rewriter, loc, sgprs[0], consts[1], 30);
3155
3156 auto type = cast<TDMGatherBaseType>(op.getResult().getType());
3157 Type indexType = type.getIndexType();
3158 unsigned indexSize = indexType.getIntOrFloatBitWidth();
3159 assert(llvm::is_contained({16u, 32u}, indexSize) &&
3160 "expected index_size to be 16 or 32");
3161 unsigned idx = (indexSize / 16) - 1;
3162
3163 if (idx)
3164 sgprs[0] = setValueAtOffset(rewriter, loc, sgprs[0], consts[1], 31);
3165 }
3166
3167 ValueRange ldsIndices = adaptor.getLdsIndices();
3168 Value lds = adaptor.getLds();
3169 auto ldsMemRefType = cast<MemRefType>(op.getLds().getType());
3170
3172 rewriter, loc, ldsMemRefType, lds, ldsIndices);
3173
3174 ValueRange globalIndices = adaptor.getGlobalIndices();
3175 Value global = adaptor.getGlobal();
3176 auto globalMemRefType = cast<MemRefType>(op.getGlobal().getType());
3177
3179 rewriter, loc, globalMemRefType, global, globalIndices);
3180
3181 Type i32 = rewriter.getI32Type();
3182 Type i64 = rewriter.getI64Type();
3183
3184 sgprs[1] = LLVM::PtrToIntOp::create(rewriter, loc, i32, ldsPtr);
3185 Value castForGlobalAddr =
3186 LLVM::PtrToIntOp::create(rewriter, loc, i64, globalPtr);
3187
3188 sgprs[2] = LLVM::TruncOp::create(rewriter, loc, i32, castForGlobalAddr);
3189
3190 Value shift = LLVM::LShrOp::create(rewriter, loc, castForGlobalAddr,
3191 createI64Constant(rewriter, loc, 32));
3192
3193 Value highHalf = LLVM::TruncOp::create(rewriter, loc, i32, shift);
3194
3195 Value mask = createI32Constant(rewriter, loc, (1ull << 25) - 1);
3196 highHalf = LLVM::AndOp::create(rewriter, loc, highHalf, mask);
3197
3198 sgprs[3] = setValueAtOffset(rewriter, loc, highHalf, consts[2], 30);
3199
3200 Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
3201 assert(v4i32 && "expected type conversion to succeed");
3202 Value result = LLVM::PoisonOp::create(rewriter, loc, v4i32);
3203
3204 for (auto [sgpr, constant] : llvm::zip_equal(sgprs, consts))
3205 result =
3206 LLVM::InsertElementOp::create(rewriter, loc, result, sgpr, constant);
3207
3208 rewriter.replaceOp(op, result);
3209 return success();
3210 }
3211};
3212
/// Lowers amdgpu make_dma_descriptor-style ops (gfx1250 TDM) into up to four
/// descriptor groups: dgroup0 is the pre-built base (a <4 x i32>), dgroup1 is
/// a <8 x i32> of control/shape words, and dgroup2/dgroup3 are <4 x i32> of
/// either higher-rank shape/stride/increment data (non-gather) or gather
/// indices. Each `set*` helper ORs one field into a 32-bit word at a fixed
/// bit offset via setValueAtOffset; fields must therefore never overlap.
template <typename DescriptorOp>
struct AMDGPULowerDescriptor : public ConvertOpToLLVMPattern<DescriptorOp> {
  using ConvertOpToLLVMPattern<DescriptorOp>::ConvertOpToLLVMPattern;
  using OpAdaptor = typename ConvertOpToLLVMPattern<DescriptorOp>::OpAdaptor;

  AMDGPULowerDescriptor(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<DescriptorOp>(converter), chipset(chipset) {}
  Chipset chipset;

  /// Group 0 is the already-lowered base descriptor operand, passed through.
  Value getDGroup0(OpAdaptor adaptor) const { return adaptor.getBase(); }

  /// Bits [15:0] of dgroup1 word 0: the optional workgroup mask, bitcast from
  /// its source type to i16 and zero-extended. No-op when the mask is absent.
  Value setWorkgroupMask(DescriptorOp op, OpAdaptor adaptor,
                         ConversionPatternRewriter &rewriter, Location loc,
                         Value sgpr0) const {
    Value mask = op.getWorkgroupMask();
    if (!mask)
      return sgpr0;

    Type i16 = rewriter.getI16Type();
    mask = LLVM::BitcastOp::create(rewriter, loc, i16, mask);
    Type i32 = rewriter.getI32Type();
    Value extendedMask = LLVM::ZExtOp::create(rewriter, loc, i32, mask);
    return setValueAtOffset(rewriter, loc, sgpr0, extendedMask, 0);
  }

  /// Bits [17:16] of dgroup1 word 0: element size encoded as
  /// log2(width-in-bytes), i.e. 8/16/32/64-bit elements map to 0/1/2/3.
  Value setDataSize(DescriptorOp op, OpAdaptor adaptor,
                    ConversionPatternRewriter &rewriter, Location loc,
                    Value sgpr0, ArrayRef<Value> consts) const {
    unsigned elementTypeWidthInBits = op.getElementTypeWidth();
    assert(llvm::is_contained({8u, 16u, 32u, 64u}, elementTypeWidthInBits) &&
           "expected type width to be 8, 16, 32, or 64.");
    int64_t idx = llvm::Log2_32(elementTypeWidthInBits / 8);
    Value size = consts[idx];
    return setValueAtOffset(rewriter, loc, sgpr0, size, 16);
  }

  /// Bit 18 of dgroup1 word 0: atomic-barrier enable, set only when an
  /// atomic barrier address operand is present.
  Value setAtomicBarrier(DescriptorOp op, OpAdaptor adaptor,
                         ConversionPatternRewriter &rewriter, Location loc,
                         Value sgpr0, ArrayRef<Value> consts) const {
    if (!adaptor.getAtomicBarrierAddress())
      return sgpr0;

    return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 18);
  }

  /// Bit 19 of dgroup1 word 0: iterate enable, set when a global increment is
  /// provided.
  Value setIterateEnable(DescriptorOp op, OpAdaptor adaptor,
                         ConversionPatternRewriter &rewriter, Location loc,
                         Value sgpr0, ArrayRef<Value> consts) const {
    if (!adaptor.getGlobalIncrement())
      return sgpr0;

    // Value is ignored when in gather mode.
    // TODO: emit error earlier?
    return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 19);
  }

  /// Bit 20 of dgroup1 word 0: padding enable, set when a pad amount is given.
  Value setPadEnable(DescriptorOp op, OpAdaptor adaptor,
                     ConversionPatternRewriter &rewriter, Location loc,
                     Value sgpr0, ArrayRef<Value> consts) const {
    if (!op.getPadAmount())
      return sgpr0;

    return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 20);
  }

  /// Bit 21 of dgroup1 word 0. NOTE(review): gated on the workgroup mask
  /// being present, not on a dedicated early-timeout operand — presumably the
  /// two features are tied together in the op definition; confirm.
  Value setEarlyTimeout(DescriptorOp op, OpAdaptor adaptor,
                        ConversionPatternRewriter &rewriter, Location loc,
                        Value sgpr0, ArrayRef<Value> consts) const {
    if (!op.getWorkgroupMask())
      return sgpr0;

    return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 21);
  }

  /// Bits starting at 22 of dgroup1 word 0: pad interval encoded as
  /// log2(interval) - 1. Only emitted when padding is enabled.
  Value setPadInterval(DescriptorOp op, OpAdaptor adaptor,
                       ConversionPatternRewriter &rewriter, Location loc,
                       Value sgpr0, ArrayRef<Value> consts) const {
    if (!op.getPadAmount())
      return sgpr0;

    // pre-condition: padInterval can be a power of two between 2 and 256.
    // TODO: Validation if the value breaks the pre-condition.
    // If the pre-condition fails, there is a possibility of
    // affecting the higher bits. In a following PR implement
    // RuntimeVerifiableOpInterface that instruments conditions that need to be
    // checked at runtime.
    IntegerType i32 = rewriter.getI32Type();
    Value padInterval = adaptor.getPadInterval();
    padInterval = LLVM::CountTrailingZerosOp::create(rewriter, loc, i32,
                                                     padInterval, false);
    padInterval = LLVM::SubOp::create(rewriter, loc, padInterval, consts[1]);
    // post-condition: padInterval can be a value between 0 and 7.
    return setValueAtOffset(rewriter, loc, sgpr0, padInterval, 22);
  }

  /// Bits starting at 25 of dgroup1 word 0: pad amount, stored biased by -1.
  Value setPadAmount(DescriptorOp op, OpAdaptor adaptor,
                     ConversionPatternRewriter &rewriter, Location loc,
                     Value sgpr0, ArrayRef<Value> consts) const {
    if (!op.getPadAmount())
      return sgpr0;

    // pre-condition: padAmount is a value between 1-128.
    // TODO: Validation if the value breaks the pre-condition.
    // If the pre-condition fails, there is a possibility of
    // affecting the higher bits. In a following PR implement
    // RuntimeVerifiableOpInterface that instruments conditions that need to be
    // checked at runtime.
    Value padAmount = adaptor.getPadAmount();
    padAmount = LLVM::SubOp::create(rewriter, loc, padAmount, consts[1]);
    // post-condition: padAmount is a value between 0-127.
    return setValueAtOffset(rewriter, loc, sgpr0, padAmount, 25);
  }

  /// Bits [15:0] of dgroup1 word 1 (absolute offset 32): the atomic barrier
  /// address, resolved to a strided element pointer, truncated to i32, and
  /// shifted right by 3 (8-byte aligned, so the 3 LSBs are implicit).
  Value setAtomicBarrierAddress(DescriptorOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter,
                                Location loc, Value sgpr1,
                                ArrayRef<Value> consts) const {
    if (!adaptor.getAtomicBarrierAddress())
      return sgpr1;

    Value atomicBarrierAddress = adaptor.getAtomicBarrierAddress();
    auto barrierAddressTy =
        cast<MemRefType>(op.getAtomicBarrierAddress().getType());
    ValueRange atomicBarrierIndices = adaptor.getAtomicBarrierIndices();
    atomicBarrierAddress = ConvertToLLVMPattern::getStridedElementPtr(
        rewriter, loc, barrierAddressTy, atomicBarrierAddress,
        atomicBarrierIndices);
    IntegerType i32 = rewriter.getI32Type();
    // pre-condition: atomicBarrierAddress is aligned to 8 bytes which implies
    // that the 3 LSBs are zero.
    // TODO: Validation if the value breaks the pre-condition.
    // In a following PR implement RuntimeVerifiableOpInterface
    // that instruments conditions that need to be checked at runtime.
    atomicBarrierAddress =
        LLVM::PtrToIntOp::create(rewriter, loc, i32, atomicBarrierAddress);
    atomicBarrierAddress =
        LLVM::LShrOp::create(rewriter, loc, atomicBarrierAddress, consts[3]);
    Value mask = createI32Constant(rewriter, loc, 0xFFFF);
    atomicBarrierAddress =
        LLVM::AndOp::create(rewriter, loc, atomicBarrierAddress, mask);
    return setValueAtOffset(rewriter, loc, sgpr1, atomicBarrierAddress, 32);
  }

  /// Writes tensor dimension `dimX` (counted from the innermost/last global
  /// size) as a 32-bit value straddling two words: low bits at `offset` into
  /// sgpr1, the high 16 bits at `offset + 16` into sgpr2. No-op when the
  /// tensor has fewer than dimX+1 dimensions.
  std::pair<Value, Value> setTensorDimX(DescriptorOp op, OpAdaptor adaptor,
                                        ConversionPatternRewriter &rewriter,
                                        Location loc, Value sgpr1, Value sgpr2,
                                        ArrayRef<Value> consts, uint64_t dimX,
                                        uint32_t offset) const {
    ArrayRef<int64_t> globalStaticSizes = adaptor.getGlobalStaticSizes();
    ValueRange globalDynamicSizes = adaptor.getGlobalDynamicSizes();
    SmallVector<OpFoldResult> mixedGlobalSizes =
        getMixedValues(globalStaticSizes, globalDynamicSizes, rewriter);
    if (mixedGlobalSizes.size() <= dimX)
      return {sgpr1, sgpr2};

    OpFoldResult tensorDimXOpFoldResult = *(mixedGlobalSizes.rbegin() + dimX);
    // pre-condition: tensorDimX is less than 2^32-1
    // TODO: Validation if the value breaks the pre-condition.
    // In a following PR implement RuntimeVerifiableOpInterface that instruments
    // conditions that need to be checked at runtime. This could also be fixed
    // by saying that mixedGlobalSizes is a DynamicI32List.
    Value tensorDimX;
    if (auto attr = dyn_cast<Attribute>(tensorDimXOpFoldResult)) {
      tensorDimX =
          createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
    } else {
      IntegerType i32 = rewriter.getI32Type();
      tensorDimX = cast<Value>(tensorDimXOpFoldResult);
      tensorDimX = LLVM::TruncOp::create(rewriter, loc, i32, tensorDimX);
    }

    sgpr1 = setValueAtOffset(rewriter, loc, sgpr1, tensorDimX, offset);

    // Spill the upper 16 bits into the following word.
    Value c16 = createI32Constant(rewriter, loc, 16);
    Value tensorDimXHigh = LLVM::LShrOp::create(rewriter, loc, tensorDimX, c16);
    sgpr2 = setValueAtOffset(rewriter, loc, sgpr2, tensorDimXHigh, offset + 16);
    return {sgpr1, sgpr2};
  }

  /// Tensor dim 0 at absolute bit offset 48 (words 1 and 2 of dgroup1).
  std::pair<Value, Value> setTensorDim0(DescriptorOp op, OpAdaptor adaptor,
                                        ConversionPatternRewriter &rewriter,
                                        Location loc, Value sgpr1, Value sgpr2,
                                        ArrayRef<Value> consts) const {
    return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, sgpr2, consts, 0,
                         48);
  }

  /// Tensor dim 1 at absolute bit offset 80 (words 2 and 3 of dgroup1).
  std::pair<Value, Value> setTensorDim1(DescriptorOp op, OpAdaptor adaptor,
                                        ConversionPatternRewriter &rewriter,
                                        Location loc, Value sgpr2, Value sgpr3,
                                        ArrayRef<Value> consts) const {
    return setTensorDimX(op, adaptor, rewriter, loc, sgpr2, sgpr3, consts, 1,
                         80);
  }

  /// Writes tile (shared/LDS) dimension `dimX` (counted from the innermost
  /// shared size) as a 16-bit field at `offset`. No-op when the tile has
  /// fewer than dimX+1 dimensions.
  Value setTileDimX(DescriptorOp op, OpAdaptor adaptor,
                    ConversionPatternRewriter &rewriter, Location loc,
                    Value sgpr, ArrayRef<Value> consts, size_t dimX,
                    int64_t offset) const {
    ArrayRef<int64_t> sharedStaticSizes = adaptor.getSharedStaticSizes();
    ValueRange sharedDynamicSizes = adaptor.getSharedDynamicSizes();
    SmallVector<OpFoldResult> mixedSharedSizes =
        getMixedValues(sharedStaticSizes, sharedDynamicSizes, rewriter);
    if (mixedSharedSizes.size() <= dimX)
      return sgpr;

    OpFoldResult tileDimXOpFoldResult = *(mixedSharedSizes.rbegin() + dimX);
    // pre-condition: tileDimX is less than 2^16-1
    // TODO: Validation if the value breaks the pre-condition.
    // If the pre-condition fails, there is a possibility of
    // affecting the higher bits. In a following PR implement
    // RuntimeVerifiableOpInterface that instruments conditions that need to be
    // checked at runtime. This could also be fixed by saying that
    // mixedSharedSizes is a DynamicI16List.
    Value tileDimX;
    if (auto attr = dyn_cast<Attribute>(tileDimXOpFoldResult)) {
      tileDimX =
          createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
    } else {
      IntegerType i32 = rewriter.getI32Type();
      tileDimX = cast<Value>(tileDimXOpFoldResult);
      tileDimX = LLVM::TruncOp::create(rewriter, loc, i32, tileDimX);
    }

    return setValueAtOffset(rewriter, loc, sgpr, tileDimX, offset);
  }

  /// Tile dim 0 at absolute bit offset 112 (word 3 of dgroup1).
  Value setTileDim0(DescriptorOp op, OpAdaptor adaptor,
                    ConversionPatternRewriter &rewriter, Location loc,
                    Value sgpr3, ArrayRef<Value> consts) const {
    return setTileDimX(op, adaptor, rewriter, loc, sgpr3, consts, 0, 112);
  }

  /// Tile dim 1 at absolute bit offset 128 (word 4 of dgroup1).
  Value setTileDim1(DescriptorOp op, OpAdaptor adaptor,
                    ConversionPatternRewriter &rewriter, Location loc,
                    Value sgpr4, ArrayRef<Value> consts) const {
    return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 1, 128);
  }

  /// Gather mode only: number of valid gather indices (1..16), written at the
  /// same offset tile dim 1 would occupy (128).
  Value setValidIndices(DescriptorOp op, OpAdaptor adaptor,
                        ConversionPatternRewriter &rewriter, Location loc,
                        Value sgpr4, ArrayRef<Value> consts) const {
    auto type = cast<VectorType>(op.getIndices().getType());
    ArrayRef<int64_t> shape = type.getShape();
    assert(shape.size() == 1 && "expected shape to be of rank 1.");
    unsigned length = shape.back();
    assert(0 < length && length <= 16 && "expected length to be at most 16.");
    Value value = createI32Constant(rewriter, loc, length);
    return setValueAtOffset(rewriter, loc, sgpr4, value, 128);
  }

  /// Bit offset 128 is overloaded: valid-index count in gather mode,
  /// tile dim 1 otherwise.
  Value setTileDim1OrValidIndices(DescriptorOp op, OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter,
                                  Location loc, Value sgpr4,
                                  ArrayRef<Value> consts) const {
    if constexpr (DescriptorOp::isGather())
      return setValidIndices(op, adaptor, rewriter, loc, sgpr4, consts);
    return setTileDim1(op, adaptor, rewriter, loc, sgpr4, consts);
  }

  /// Tile dim 2 at absolute bit offset 144; skipped in gather mode where the
  /// field is ignored by hardware.
  Value setTileDim2(DescriptorOp op, OpAdaptor adaptor,
                    ConversionPatternRewriter &rewriter, Location loc,
                    Value sgpr4, ArrayRef<Value> consts) const {
    // Value is ignored when in gather mode.
    if constexpr (DescriptorOp::isGather())
      return sgpr4;
    return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 2, 144);
  }

  /// Writes the stride of tensor dimension `dimX` as a 48-bit value split
  /// across two words: the low 32 bits at `offset` into sgprY, the remaining
  /// high bits into sgprZ. Strides are read from index dimX+1 (innermost
  /// first) since the innermost stride is implicit. No-op when the tensor
  /// lacks that dimension.
  std::pair<Value, Value>
  setTensorDimXStride(DescriptorOp op, OpAdaptor adaptor,
                      ConversionPatternRewriter &rewriter, Location loc,
                      Value sgprY, Value sgprZ, ArrayRef<Value> consts,
                      size_t dimX, int64_t offset) const {
    ArrayRef<int64_t> globalStaticStrides = adaptor.getGlobalStaticStrides();
    ValueRange globalDynamicStrides = adaptor.getGlobalDynamicStrides();
    SmallVector<OpFoldResult> mixedGlobalStrides =
        getMixedValues(globalStaticStrides, globalDynamicStrides, rewriter);

    if (mixedGlobalStrides.size() <= (dimX + 1))
      return {sgprY, sgprZ};

    OpFoldResult tensorDimXStrideOpFoldResult =
        *(mixedGlobalStrides.rbegin() + dimX + 1);
    // pre-condition: tensorDimXStride is less than 2^48-1
    // TODO: Validation if the value breaks the pre-condition.
    // In a following PR implement RuntimeVerifiableOpInterface that instruments
    // conditions that need to be checked at runtime.
    Value tensorDimXStride;
    if (auto attr = dyn_cast<Attribute>(tensorDimXStrideOpFoldResult))
      tensorDimXStride =
          createI64Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
    else
      tensorDimXStride = cast<Value>(tensorDimXStrideOpFoldResult);

    // Clamp to 48 bits before splitting.
    constexpr int64_t first48bits = (1ll << 48) - 1;
    Value mask = createI64Constant(rewriter, loc, first48bits);
    tensorDimXStride =
        LLVM::AndOp::create(rewriter, loc, mask, tensorDimXStride);
    IntegerType i32 = rewriter.getI32Type();
    Value tensorDimXStrideLow =
        LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStride);
    sgprY = setValueAtOffset(rewriter, loc, sgprY, tensorDimXStrideLow, offset);

    // The high part starts at the next word boundary relative to `offset`.
    int64_t shift = (offset % 32) == 0 ? 32 : offset % 32;
    Value shiftVal = createI64Constant(rewriter, loc, shift);
    Value tensorDimXStrideHigh =
        LLVM::LShrOp::create(rewriter, loc, tensorDimXStride, shiftVal);
    tensorDimXStrideHigh =
        LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStrideHigh);
    sgprZ = setValueAtOffset(rewriter, loc, sgprZ, tensorDimXStrideHigh,
                             offset + shift);
    return {sgprY, sgprZ};
  }

  /// Stride of dim 0 at absolute bit offset 160 (words 5 and 6 of dgroup1).
  std::pair<Value, Value>
  setTensorDim0Stride(DescriptorOp op, OpAdaptor adaptor,
                      ConversionPatternRewriter &rewriter, Location loc,
                      Value sgpr5, Value sgpr6, ArrayRef<Value> consts) const {
    return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts,
                               0, 160);
  }

  /// Stride of dim 1 at absolute bit offset 208; skipped in gather mode.
  std::pair<Value, Value>
  setTensorDim1Stride(DescriptorOp op, OpAdaptor adaptor,
                      ConversionPatternRewriter &rewriter, Location loc,
                      Value sgpr5, Value sgpr6, ArrayRef<Value> consts) const {
    // Value is ignored when in gather mode.
    if constexpr (DescriptorOp::isGather())
      return {sgpr5, sgpr6};
    return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts,
                               1, 208);
  }

  /// Assembles dgroup1 (<8 x i32>): control word, barrier address, dims 0/1,
  /// tile dims, and dim-0/1 strides, then packs the eight words into a vector.
  Value getDGroup1(DescriptorOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter, Location loc,
                   ArrayRef<Value> consts) const {
    Value sgprs[8];
    for (int64_t i = 0; i < 8; ++i) {
      sgprs[i] = consts[0];
    }

    sgprs[0] = setWorkgroupMask(op, adaptor, rewriter, loc, sgprs[0]);
    sgprs[0] = setDataSize(op, adaptor, rewriter, loc, sgprs[0], consts);
    sgprs[0] = setAtomicBarrier(op, adaptor, rewriter, loc, sgprs[0], consts);
    sgprs[0] = setIterateEnable(op, adaptor, rewriter, loc, sgprs[0], consts);
    sgprs[0] = setPadEnable(op, adaptor, rewriter, loc, sgprs[0], consts);
    sgprs[0] = setEarlyTimeout(op, adaptor, rewriter, loc, sgprs[0], consts);
    sgprs[0] = setPadInterval(op, adaptor, rewriter, loc, sgprs[0], consts);
    sgprs[0] = setPadAmount(op, adaptor, rewriter, loc, sgprs[0], consts);

    sgprs[1] =
        setAtomicBarrierAddress(op, adaptor, rewriter, loc, sgprs[1], consts);
    std::tie(sgprs[1], sgprs[2]) =
        setTensorDim0(op, adaptor, rewriter, loc, sgprs[1], sgprs[2], consts);
    std::tie(sgprs[2], sgprs[3]) =
        setTensorDim1(op, adaptor, rewriter, loc, sgprs[2], sgprs[3], consts);

    sgprs[3] = setTileDim0(op, adaptor, rewriter, loc, sgprs[3], consts);
    sgprs[4] =
        setTileDim1OrValidIndices(op, adaptor, rewriter, loc, sgprs[4], consts);
    sgprs[4] = setTileDim2(op, adaptor, rewriter, loc, sgprs[4], consts);
    std::tie(sgprs[5], sgprs[6]) = setTensorDim0Stride(
        op, adaptor, rewriter, loc, sgprs[5], sgprs[6], consts);
    std::tie(sgprs[6], sgprs[7]) = setTensorDim1Stride(
        op, adaptor, rewriter, loc, sgprs[6], sgprs[7], consts);

    IntegerType i32 = rewriter.getI32Type();
    Type v8i32 = this->typeConverter->convertType(VectorType::get(8, i32));
    assert(v8i32 && "expected type conversion to succeed");
    Value dgroup1 = LLVM::PoisonOp::create(rewriter, loc, v8i32);

    for (auto [sgpr, constant] : llvm::zip_equal(sgprs, consts)) {
      dgroup1 =
          LLVM::InsertElementOp::create(rewriter, loc, dgroup1, sgpr, constant);
    }

    return dgroup1;
  }

  /// Single-word overload for dgroup2+: writes tensor dimension `dimX` as a
  /// 32-bit field at `offset` without spilling high bits into a second word.
  Value setTensorDimX(DescriptorOp op, OpAdaptor adaptor,
                      ConversionPatternRewriter &rewriter, Location loc,
                      Value sgpr0, ArrayRef<Value> consts, int64_t dimX,
                      int64_t offset) const {
    ArrayRef<int64_t> globalStaticSizes = adaptor.getGlobalStaticSizes();
    ValueRange globalDynamicSizes = adaptor.getGlobalDynamicSizes();
    SmallVector<OpFoldResult> mixedGlobalSizes =
        getMixedValues(globalStaticSizes, globalDynamicSizes, rewriter);
    if (mixedGlobalSizes.size() <= static_cast<unsigned long>(dimX))
      return sgpr0;

    OpFoldResult tensorDimXOpFoldResult = *(mixedGlobalSizes.rbegin() + dimX);
    Value tensorDimX;
    if (auto attr = dyn_cast<Attribute>(tensorDimXOpFoldResult)) {
      tensorDimX =
          createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
    } else {
      IntegerType i32 = rewriter.getI32Type();
      tensorDimX = cast<Value>(tensorDimXOpFoldResult);
      tensorDimX = LLVM::TruncOp::create(rewriter, loc, i32, tensorDimX);
    }

    return setValueAtOffset(rewriter, loc, sgpr0, tensorDimX, offset);
  }

  /// Tensor dim 2 in word 0 of dgroup2.
  Value setTensorDim2(DescriptorOp op, OpAdaptor adaptor,
                      ConversionPatternRewriter &rewriter, Location loc,
                      Value sgpr0, ArrayRef<Value> consts) const {
    return setTensorDimX(op, adaptor, rewriter, loc, sgpr0, consts, 2, 0);
  }

  /// Truncates a (presumably 64-bit) value to i32 before ORing it in at
  /// `shift`.
  Value truncateAndSetValueAtOffset(ConversionPatternRewriter &rewriter,
                                    Location loc, Value accumulator,
                                    Value value, int64_t shift) const {

    IntegerType i32 = rewriter.getI32Type();
    value = LLVM::TruncOp::create(rewriter, loc, i32, value);
    return setValueAtOffset(rewriter, loc, accumulator, value, shift);
  }

  /// Writes the LDS address increment into word 1 of dgroup2.
  Value setLDSAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
                            ConversionPatternRewriter &rewriter, Location loc,
                            Value sgpr1, ArrayRef<Value> consts,
                            int64_t offset) const {
    Value ldsAddrIncrement = adaptor.getLdsIncrement();
    return setValueAtOffset(rewriter, loc, sgpr1, ldsAddrIncrement, offset);
  }

  /// Writes the 48-bit global address increment across words 2 and 3 of
  /// dgroup2: low 32 bits at `offset`, then the next 16 bits (masked) in the
  /// following word.
  std::pair<Value, Value>
  setGlobalAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
                         ConversionPatternRewriter &rewriter, Location loc,
                         Value sgpr2, Value sgpr3, ArrayRef<Value> consts,
                         int64_t offset) const {
    Value globalAddrIncrement = adaptor.getGlobalIncrement();
    sgpr2 = truncateAndSetValueAtOffset(rewriter, loc, sgpr2,
                                        globalAddrIncrement, offset);
    Value shift = createI64Constant(rewriter, loc, 32);
    globalAddrIncrement =
        LLVM::LShrOp::create(rewriter, loc, globalAddrIncrement, shift);
    constexpr int64_t first16BitsHigh = (1ll << 16) - 1;
    sgpr3 = truncateAndSetValueAtOffset(rewriter, loc, sgpr3,
                                        globalAddrIncrement, offset + 32);
    // NOTE(review): this AND masks the whole of sgpr3, not just the newly
    // written field — sgpr3 must hold no other bits above 16 at this point.
    Value mask = createI32Constant(rewriter, loc, first16BitsHigh);
    sgpr3 = LLVM::AndOp::create(rewriter, loc, sgpr3, mask);
    return {sgpr2, sgpr3};
  }

  /// Word 1 of dgroup2 is overloaded: LDS increment when iterating,
  /// tensor dim 3 otherwise.
  Value setTensorDim3OrLDSAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
                                        ConversionPatternRewriter &rewriter,
                                        Location loc, Value sgpr1,
                                        ArrayRef<Value> consts) const {
    Value ldsIncrement = op.getLdsIncrement();
    constexpr int64_t dim = 3;
    constexpr int64_t offset = 32;
    if (!ldsIncrement)
      return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, consts, dim,
                           offset);
    return setLDSAddrIncrement(op, adaptor, rewriter, loc, sgpr1, consts,
                               offset);
  }

  /// Words 2-3 of dgroup2 are overloaded: global increment when iterating,
  /// dim-2 stride otherwise.
  std::pair<Value, Value> setTensorDim2StrideOrGlobalAddrIncrement(
      DescriptorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter,
      Location loc, Value sgpr2, Value sgpr3, ArrayRef<Value> consts) const {
    Value globalIncrement = op.getGlobalIncrement();
    constexpr int32_t dim = 2;
    constexpr int32_t offset = 64;
    if (!globalIncrement)
      return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr2, sgpr3,
                                 consts, dim, offset);
    return setGlobalAddrIncrement(op, adaptor, rewriter, loc, sgpr2, sgpr3,
                                  consts, offset);
  }

  /// Iteration count at `offset`, stored biased by -1 after truncation to i32.
  Value setIterateCount(DescriptorOp op, OpAdaptor adaptor,
                        ConversionPatternRewriter &rewriter, Location loc,
                        Value sgpr3, ArrayRef<Value> consts,
                        int32_t offset) const {
    Value iterationCount = adaptor.getIterationCount();
    IntegerType i32 = rewriter.getI32Type();
    // pre-condition: iterationCount is in the inclusive interval [1, 256].
    // TODO: validation if the value breaks the pre-condition.
    // If the pre-condition fails, there is a possibility of
    // affecting the higher bits. In a following PR implement
    // RuntimeVerifiableOpInterface that instruments conditions that need to be
    // checked at runtime.
    iterationCount = LLVM::TruncOp::create(rewriter, loc, i32, iterationCount);
    iterationCount =
        LLVM::SubOp::create(rewriter, loc, iterationCount, consts[1]);
    return setValueAtOffset(rewriter, loc, sgpr3, iterationCount, offset);
  }

  /// Bit offset 112 of dgroup2 is overloaded: iteration count when iterating,
  /// tile dim 2 otherwise (note: dim constant is 2, despite the "Dim3" name).
  Value setTileDim3OrIterateCount(DescriptorOp op, OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter,
                                  Location loc, Value sgpr3,
                                  ArrayRef<Value> consts) const {
    Value iterateCount = op.getIterationCount();
    constexpr int32_t dim = 2;
    constexpr int32_t offset = 112;
    if (!iterateCount)
      return setTileDimX(op, adaptor, rewriter, loc, sgpr3, consts, dim,
                         offset);

    return setIterateCount(op, adaptor, rewriter, loc, sgpr3, consts, offset);
  }

  /// Dgroup2 dispatch: gather indices (first half) in gather mode, shape /
  /// increment data otherwise.
  Value getDGroup2(DescriptorOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter, Location loc,
                   ArrayRef<Value> consts) const {
    if constexpr (DescriptorOp::isGather())
      return getDGroup2Gather(op, adaptor, rewriter, loc, consts);
    return getDGroup2NonGather(op, adaptor, rewriter, loc, consts);
  }

  /// Non-gather dgroup2: all-zero when rank <= 2 and no LDS increment (only
  /// two descriptor groups needed); otherwise dim 2, dim 3 / LDS increment,
  /// dim-2 stride / global increment, and tile dim / iterate count.
  Value getDGroup2NonGather(DescriptorOp op, OpAdaptor adaptor,
                            ConversionPatternRewriter &rewriter, Location loc,
                            ArrayRef<Value> consts) const {
    IntegerType i32 = rewriter.getI32Type();
    Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
    assert(v4i32 && "expected type conversion to succeed.");

    bool onlyNeedsTwoDescriptors = !op.getLdsIncrement() && op.getRank() <= 2;
    if (onlyNeedsTwoDescriptors)
      return LLVM::ZeroOp::create(rewriter, loc, v4i32);

    constexpr int64_t sgprlen = 4;
    Value sgprs[sgprlen];
    for (int i = 0; i < sgprlen; ++i)
      sgprs[i] = consts[0];

    sgprs[0] = setTensorDim2(op, adaptor, rewriter, loc, sgprs[0], consts);
    sgprs[1] = setTensorDim3OrLDSAddrIncrement(op, adaptor, rewriter, loc,
                                               sgprs[1], consts);
    std::tie(sgprs[2], sgprs[3]) = setTensorDim2StrideOrGlobalAddrIncrement(
        op, adaptor, rewriter, loc, sgprs[2], sgprs[3], consts);
    sgprs[3] =
        setTileDim3OrIterateCount(op, adaptor, rewriter, loc, sgprs[3], consts);

    Value dgroup2 = LLVM::PoisonOp::create(rewriter, loc, v4i32);
    for (auto [sgpr, constant] : llvm::zip(sgprs, consts))
      dgroup2 =
          LLVM::InsertElementOp::create(rewriter, loc, dgroup2, sgpr, constant);

    return dgroup2;
  }

  /// Packs one half of the gather index vector into a <4 x i32>. 32-bit
  /// indices map one per word (up to 4 per group); 16-bit indices are
  /// zero-extended and packed two per word (up to 8 per group), padding with
  /// zero when the count is odd.
  Value getGatherIndices(DescriptorOp op, OpAdaptor adaptor,
                         ConversionPatternRewriter &rewriter, Location loc,
                         ArrayRef<Value> consts, bool firstHalf) const {
    IntegerType i32 = rewriter.getI32Type();
    Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
    assert(v4i32 && "expected type conversion to succeed.");

    Value indices = adaptor.getIndices();
    auto vectorType = cast<VectorType>(indices.getType());
    unsigned length = vectorType.getShape().back();
    Type elementType = vectorType.getElementType();
    unsigned maxLength = elementType == i32 ? 4 : 8;
    int32_t offset = firstHalf ? 0 : maxLength;
    // Number of source indices left for this half (clamped at 0).
    unsigned discountedLength =
        std::max(static_cast<int32_t>(length - offset), 0);

    unsigned targetSize = std::min(maxLength, discountedLength);

    // Extract the raw index elements for this half.
    SmallVector<Value> indicesVector;
    for (unsigned i = offset; i < targetSize + offset; ++i) {
      Value idx;
      if (i < consts.size())
        idx = consts[i];
      else
        idx = createI32Constant(rewriter, loc, i);
      Value elem = LLVM::ExtractElementOp::create(rewriter, loc, indices, idx);
      indicesVector.push_back(elem);
    }

    // Widen 16-bit indices to i32; pad to an even count for pairing.
    SmallVector<Value> indicesI32Vector;
    if (elementType == i32) {
      indicesI32Vector = indicesVector;
    } else {
      for (unsigned i = 0; i < targetSize; ++i) {
        Value index = indicesVector[i];
        indicesI32Vector.push_back(
            LLVM::ZExtOp::create(rewriter, loc, i32, index));
      }
      if ((targetSize % 2) != 0)
        // Add padding when not divisible by two.
        indicesI32Vector.push_back(consts[0]);
    }

    // For 16-bit indices, pair adjacent values into one word (second index
    // in the upper 16 bits).
    SmallVector<Value> indicesToInsert;
    if (elementType == i32) {
      indicesToInsert = indicesI32Vector;
    } else {
      unsigned size = indicesI32Vector.size() / 2;
      for (unsigned i = 0; i < size; ++i) {
        Value first = indicesI32Vector[2 * i];
        Value second = indicesI32Vector[2 * i + 1];
        Value joined = setValueAtOffset(rewriter, loc, first, second, 16);
        indicesToInsert.push_back(joined);
      }
    }

    Value dgroup = LLVM::PoisonOp::create(rewriter, loc, v4i32);
    for (auto [sgpr, constant] : llvm::zip_first(indicesToInsert, consts))
      dgroup =
          LLVM::InsertElementOp::create(rewriter, loc, dgroup, sgpr, constant);

    return dgroup;
  }

  /// Gather dgroup2: the first half of the gather indices.
  Value getDGroup2Gather(DescriptorOp op, OpAdaptor adaptor,
                         ConversionPatternRewriter &rewriter, Location loc,
                         ArrayRef<Value> consts) const {
    return getGatherIndices(op, adaptor, rewriter, loc, consts, true);
  }

  /// Stride of dim 3 at the start of dgroup3.
  std::pair<Value, Value>
  setTensorDim3Stride(DescriptorOp op, OpAdaptor adaptor,
                      ConversionPatternRewriter &rewriter, Location loc,
                      Value sgpr0, Value sgpr1, ArrayRef<Value> consts) const {
    constexpr int32_t dim = 3;
    constexpr int32_t offset = 0;
    return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr0, sgpr1, consts,
                               dim, offset);
  }

  /// Tensor dim 4 at bit offset 48 of dgroup3 (two-word split form).
  std::pair<Value, Value> setTensorDim4(DescriptorOp op, OpAdaptor adaptor,
                                        ConversionPatternRewriter &rewriter,
                                        Location loc, Value sgpr1, Value sgpr2,
                                        ArrayRef<Value> consts) const {
    constexpr int32_t dim = 4;
    constexpr int32_t offset = 48;
    return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, sgpr2, consts, dim,
                         offset);
  }

  /// Tile dim 4 at bit offset 80 of dgroup3.
  Value setTileDim4(DescriptorOp op, OpAdaptor adaptor,
                    ConversionPatternRewriter &rewriter, Location loc,
                    Value sgpr2, ArrayRef<Value> consts) const {
    constexpr int32_t dim = 4;
    constexpr int32_t offset = 80;
    return setTileDimX(op, adaptor, rewriter, loc, sgpr2, consts, dim, offset);
  }

  /// Dgroup3 dispatch: gather indices (second half) in gather mode,
  /// rank-4+ shape/stride data otherwise.
  Value getDGroup3(DescriptorOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter, Location loc,
                   ArrayRef<Value> consts) const {
    if constexpr (DescriptorOp::isGather())
      return getDGroup3Gather(op, adaptor, rewriter, loc, consts);
    return getDGroup3NonGather(op, adaptor, rewriter, loc, consts);
  }

  /// Non-gather dgroup3: all-zero when only two groups are needed; otherwise
  /// dim-3 stride, dim 4, and tile dim 4.
  Value getDGroup3NonGather(DescriptorOp op, OpAdaptor adaptor,
                            ConversionPatternRewriter &rewriter, Location loc,
                            ArrayRef<Value> consts) const {
    IntegerType i32 = rewriter.getI32Type();
    Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
    assert(v4i32 && "expected type conversion to succeed.");
    bool onlyNeedsTwoDescriptors = !op.getLdsIncrement() && op.getRank() <= 2;
    if (onlyNeedsTwoDescriptors)
      return LLVM::ZeroOp::create(rewriter, loc, v4i32);

    constexpr int32_t sgprlen = 4;
    Value sgprs[sgprlen];
    for (int i = 0; i < sgprlen; ++i)
      sgprs[i] = consts[0];

    std::tie(sgprs[0], sgprs[1]) = setTensorDim3Stride(
        op, adaptor, rewriter, loc, sgprs[0], sgprs[1], consts);
    std::tie(sgprs[1], sgprs[2]) =
        setTensorDim4(op, adaptor, rewriter, loc, sgprs[1], sgprs[2], consts);
    sgprs[2] = setTileDim4(op, adaptor, rewriter, loc, sgprs[2], consts);

    Value dgroup3 = LLVM::PoisonOp::create(rewriter, loc, v4i32);
    for (auto [sgpr, constant] : llvm::zip(sgprs, consts))
      dgroup3 =
          LLVM::InsertElementOp::create(rewriter, loc, dgroup3, sgpr, constant);

    return dgroup3;
  }

  /// Gather dgroup3: the second half of the gather indices.
  Value getDGroup3Gather(DescriptorOp op, OpAdaptor adaptor,
                         ConversionPatternRewriter &rewriter, Location loc,
                         ArrayRef<Value> consts) const {
    return getGatherIndices(op, adaptor, rewriter, loc, consts, false);
  }

  /// Entry point: builds the four descriptor groups and replaces the op with
  /// the multi-result set. Rejected on chipsets older than gfx1250.
  LogicalResult
  matchAndRewrite(DescriptorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset < kGfx1250)
      return op->emitOpError(
          "make_dma_descriptor is only supported on gfx1250");

    Location loc = op.getLoc();

    // i32 constants 0..7, shared by all helpers as field values and as
    // insertion indices.
    SmallVector<Value> consts;
    for (int64_t i = 0; i < 8; ++i)
      consts.push_back(createI32Constant(rewriter, loc, i));

    Value dgroup0 = this->getDGroup0(adaptor);
    Value dgroup1 = this->getDGroup1(op, adaptor, rewriter, loc, consts);
    Value dgroup2 = this->getDGroup2(op, adaptor, rewriter, loc, consts);
    Value dgroup3 = this->getDGroup3(op, adaptor, rewriter, loc, consts);
    SmallVector<Value> results = {dgroup0, dgroup1, dgroup2, dgroup3};
    rewriter.replaceOpWithMultiple(op, {results});
    return success();
  }
};
3922
3923template <typename SourceOp, typename TargetOp>
3924struct AMDGPUTensorLoadStoreOpLowering
3925 : public ConvertOpToLLVMPattern<SourceOp> {
3926 using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
3928 AMDGPUTensorLoadStoreOpLowering(const LLVMTypeConverter &converter,
3929 Chipset chipset)
3930 : ConvertOpToLLVMPattern<SourceOp>(converter), chipset(chipset) {}
3931 Chipset chipset;
3932
3933 LogicalResult
3934 matchAndRewrite(SourceOp op, Adaptor adaptor,
3935 ConversionPatternRewriter &rewriter) const override {
3936 if (chipset < kGfx1250)
3937 return op->emitOpError("is only supported on gfx1250");
3938
3939 ValueRange desc = adaptor.getDesc();
3940 // Create a <v8 x i32> 0 as the fifth argument to match llvm intrinsic. It
3941 // will move into the TDM descriptor once it becomes relevant for future use
3942 auto v8i32 = VectorType::get(8, rewriter.getI32Type());
3943 Value dgroup4 = LLVM::ZeroOp::create(rewriter, op.getLoc(), v8i32);
3944 rewriter.replaceOpWithNewOp<TargetOp>(op, desc[0], desc[1], desc[2],
3945 desc[3], dgroup4, /*cachePolicy=*/0,
3946 /*alias_scopes=*/nullptr,
3947 /*noalias_scopes=*/nullptr,
3948 /*tbaa=*/nullptr);
3949 return success();
3950 }
3951};
3952
3953struct ConvertAMDGPUToROCDLPass
3954 : public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
3955 using Base::Base;
3956
3957 void runOnOperation() override {
3958 MLIRContext *ctx = &getContext();
3959 FailureOr<Chipset> maybeChipset = Chipset::parse(chipset);
3960 if (failed(maybeChipset)) {
3961 emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
3962 return signalPassFailure();
3963 }
3964
3965 RewritePatternSet patterns(ctx);
3966 LLVMTypeConverter converter(ctx);
3967
3968 populateAMDGPUToROCDLConversionPatterns(converter, patterns, *maybeChipset);
3969 amdgpu::populateCommonGPUTypeAndAttributeConversions(converter);
3970 LLVMConversionTarget target(getContext());
3971 target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
3972 target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
3973 target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
3974 if (failed(applyPartialConversion(getOperation(), target,
3975 std::move(patterns))))
3976 signalPassFailure();
3977 }
3978};
3979} // namespace
3980
3982 TypeConverter &typeConverter) {
3984 typeConverter, [](gpu::AddressSpace space) {
3985 switch (space) {
3986 case gpu::AddressSpace::Global:
3987 return ROCDL::ROCDLDialect::kGlobalMemoryAddressSpace;
3988 case gpu::AddressSpace::Workgroup:
3989 return ROCDL::ROCDLDialect::kSharedMemoryAddressSpace;
3990 case gpu::AddressSpace::Private:
3991 return ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace;
3992 }
3993 llvm_unreachable("unknown address space enum value");
3994 });
3995}
3996
3998 TypeConverter &typeConverter) {
3999 typeConverter.addTypeAttributeConversion(
4000 [](BaseMemRefType type, amdgpu::AddressSpaceAttr as)
4001 -> TypeConverter::AttributeConversionResult {
4002 MLIRContext *ctx = as.getContext();
4003 Type i64 = IntegerType::get(ctx, 64);
4004 switch (as.getValue()) {
4005 case amdgpu::AddressSpace::FatRawBuffer:
4006 return IntegerAttr::get(i64, 7);
4007 case amdgpu::AddressSpace::BufferRsrc:
4008 return IntegerAttr::get(i64, 8);
4009 case amdgpu::AddressSpace::FatStructuredBuffer:
4010 return IntegerAttr::get(i64, 9);
4011 }
4012 return TypeConverter::AttributeConversionResult::abort();
4013 });
4014 typeConverter.addConversion([&](DsBarrierStateType type) -> Type {
4015 return IntegerType::get(type.getContext(), 64);
4016 });
4017 typeConverter.addConversion([&](TDMBaseType type) -> Type {
4018 Type i32 = IntegerType::get(type.getContext(), 32);
4019 return typeConverter.convertType(VectorType::get(4, i32));
4020 });
4021 typeConverter.addConversion([&](TDMGatherBaseType type) -> Type {
4022 Type i32 = IntegerType::get(type.getContext(), 32);
4023 return typeConverter.convertType(VectorType::get(4, i32));
4024 });
4025 typeConverter.addConversion(
4026 [&](TDMDescriptorType type,
4027 SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
4028 Type i32 = IntegerType::get(type.getContext(), 32);
4029 Type v4i32 = typeConverter.convertType(VectorType::get(4, i32));
4030 Type v8i32 = typeConverter.convertType(VectorType::get(8, i32));
4031 llvm::append_values(result, v4i32, v8i32, v4i32, v4i32);
4032 return success();
4033 });
4034
4035 auto addUnrealizedCast = [](OpBuilder &builder, TypeRange types,
4036 ValueRange inputs,
4038 // Only create unrealized_conversion_cast for TDMDescriptorType.
4039 // All other types which are not expected, should be
4040 // materialized by other target materialization functions.
4041 if (inputs.size() != 1)
4042 return {};
4043
4044 if (!isa<TDMDescriptorType>(inputs[0].getType()))
4045 return {};
4046
4047 auto cast = UnrealizedConversionCastOp::create(builder, loc, types, inputs);
4048 return cast.getResults();
4049 };
4050
4051 typeConverter.addTargetMaterialization(addUnrealizedCast);
4052}
4053
4055 RewritePatternSet &patterns,
4056 Chipset chipset) {
4058 patterns
4059 .add<FatRawBufferCastLowering,
4060 RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
4061 RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
4062 RawBufferOpLowering<RawBufferAtomicFaddOp,
4063 ROCDL::RawPtrBufferAtomicFaddOp>,
4064 RawBufferOpLowering<RawBufferAtomicFmaxOp,
4065 ROCDL::RawPtrBufferAtomicFmaxOp>,
4066 RawBufferOpLowering<RawBufferAtomicSmaxOp,
4067 ROCDL::RawPtrBufferAtomicSmaxOp>,
4068 RawBufferOpLowering<RawBufferAtomicUminOp,
4069 ROCDL::RawPtrBufferAtomicUminOp>,
4070 RawBufferOpLowering<RawBufferAtomicCmpswapOp,
4071 ROCDL::RawPtrBufferAtomicCmpSwap>,
4072 AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
4073 SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
4074 SparseMFMAOpLowering, WMMAOpLowering, ScaledWMMAOpLowering,
4075 SparseWMMAOpLowering, ExtPackedFp8OpLowering,
4076 ScaledExtPackedMatrixOpLowering, ScaledExtPackedOpLowering,
4077 PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
4078 PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
4079 TransposeLoadOpLowering, AMDGPUPermlaneLowering,
4080 AMDGPUMakeDmaBaseLowering<MakeDmaBaseOp>,
4081 AMDGPUMakeDmaBaseLowering<MakeGatherDmaBaseOp>,
4082 AMDGPULowerDescriptor<MakeDmaDescriptorOp>,
4083 AMDGPULowerDescriptor<MakeGatherDmaDescriptorOp>,
4084 AMDGPUTensorLoadStoreOpLowering<TensorLoadToLDSOp,
4085 ROCDL::TensorLoadToLDSOp>,
4086 AMDGPUTensorLoadStoreOpLowering<TensorStoreFromLDSOp,
4087 ROCDL::TensorStoreFromLDSOp>,
4088 DsBarrierInitOpLowering, DsBarrierPollStateOpLowering,
4089 DsAsyncBarrierArriveOpLowering, DsBarrierArriveOpLowering>(converter,
4090 chipset);
4091 patterns.add<AMDGPUSwizzleBitModeLowering, DsBarrierStatePhaseOpLowering,
4092 DsBarrierStatePendingCountOpLowering,
4093 DsBarrierStateInitCountOpLowering,
4094 DsBarrierStatePhaseParityLowering>(converter);
4095}
static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type)
Return true if type is the E4M3FN variant of an 8-bit float that is supported by the _fp8 instruction...
constexpr Chipset kGfx942
static std::optional< StringRef > wmmaOpToIntrinsicRDNA(Type elemSourceType, Type elemBSourceType, Type elemDestType, uint32_t k, bool isRDNA3)
Returns the rocdl intrinsic corresponding to a WMMA operation wmma for RDNA3/4 architectures.
static std::optional< SparseWMMAOpInfo > sparseWMMAOpToIntrinsic(SparseWMMAOp swmmac, Chipset chipset)
static std::optional< std::tuple< StringRef, uint32_t, uint32_t > > mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m, uint32_t n, uint32_t k, uint32_t b, Chipset chipset)
If there is a scaled MFMA instruction for the input element types aType and bType,...
static std::optional< StringRef > mfmaOpToIntrinsic(MFMAOp mfma, Chipset chipset)
Return the rocdl intrinsic corresponding to a MFMA operation mfma if one exists.
constexpr Chipset kGfx908
static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput, Value mlirInput, SmallVectorImpl< Value > &operands, SmallVectorImpl< NamedAttribute > &attrs, StringRef attrName)
Push an input operand.
constexpr Chipset kGfx1250
static Value castScaleOperand(ConversionPatternRewriter &rewriter, Location loc, Value input)
Converts the scaled MFMA/WMMA operands, scalesA and scalesB, from MLIR AMDGPU dialect convention to R...
constexpr Chipset kGfx90a
static std::optional< StringRef > getScaledWmmaIntrinsicName(int64_t m, int64_t n, int64_t k, bool isScale16)
Determines the ROCDL intrinsic name for scaled WMMA based on dimensions and scale block size (16 or 3...
static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, Value output, int32_t subwordOffset, bool clamp, SmallVectorImpl< Value > &operands, SmallVectorImpl< NamedAttribute > &attrs)
Push the output operand.
static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type)
Return true if type is the E5M2 variant of an 8-bit float that is supported by the _bf8 instructions ...
static std::optional< StringRef > wmmaOpToIntrinsic(WMMAOp wmma, Chipset chipset)
Returns the rocdl intrinsic corresponding to a WMMA operation wmma if one exists.
static std::optional< StringRef > smfmacOpToIntrinsic(SparseMFMAOp op, Chipset chipset)
Returns the rocdl intrinsic corresponding to a SparseMFMA (smfmac) operation if one exists.
static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc, Value basePointer, Value numRecords, bool boundsCheck, amdgpu::Chipset chipset, Value cacheSwizzleStride=nullptr, unsigned addressSpace=8)
static Value createI64Constant(ConversionPatternRewriter &rewriter, Location loc, int64_t value)
static std::optional< StringRef > wmmaOpToIntrinsicGfx1250(Type elemSourceType, Type elemBSourceType, Type elemDestType, uint32_t k)
Return the rocdl intrinsic corresponding to a WMMA operation wmma for the gfx1250 architecture.
static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc, MemRefType memrefType, MemRefDescriptor &memrefDescriptor, ArrayRef< int64_t > strides, int64_t elementByteWidth, amdgpu::Chipset chipset, bool boundsCheck)
Compute the contents of the num_records field for a given memref descriptor - that is,...
static Value packSmallFloatVectorOperand(ConversionPatternRewriter &rewriter, Location loc, Value input, bool allowBf16=true)
Pack small float vector operands (fp4/fp6/fp8/bf16) into the format expected by scaled matrix multipl...
static std::optional< uint32_t > getWmmaScaleFormat(Type elemType)
Maps f8 scale element types to WMMA scale format codes.
static Value getLinearIndexI32(ConversionPatternRewriter &rewriter, Location loc, MemRefDescriptor &memRefDescriptor, ValueRange indices, ArrayRef< int64_t > strides)
Returns the linear index used to access an element in the memref.
static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter, Location loc, Value val)
Convert an unsigned number val to i32.
static Value createI32Constant(ConversionPatternRewriter &rewriter, Location loc, int32_t value)
static std::optional< uint32_t > smallFloatTypeToFormatCode(Type mlirElemType)
static Value convertUnsignedToI64(ConversionPatternRewriter &rewriter, Location loc, Value val)
Convert an unsigned number val to i64.
constexpr Chipset kGfx950
static Value convertSparseVectorOperand(ConversionPatternRewriter &rewriter, Location loc, Value input, bool allowBf16=true)
Converts sparse MFMA/WMMA (smfmac/swmmac) operands to the expected ROCDL types.
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
b getContext())
auto load
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static constexpr unsigned kSizePosInMemRefDescriptor
static constexpr unsigned kStridePosInMemRefDescriptor
static constexpr unsigned kOffsetPosInMemRefDescriptor
static constexpr unsigned kAllocatedPtrPosInMemRefDescriptor
static constexpr unsigned kAlignedPtrPosInMemRefDescriptor
static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound)
This class provides a shared interface for ranked and unranked memref types.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition Pattern.h:227
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition Pattern.h:233
typename SourceOp::template GenericAdaptor< ArrayRef< ValueRange > > OneToNOpAdaptor
Definition Pattern.h:230
typename SourceOp::Adaptor OpAdaptor
Definition Pattern.h:229
Value getStridedElementPtr(ConversionPatternRewriter &rewriter, Location loc, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none) const
Convenience wrapper for the corresponding helper utility.
Definition Pattern.cpp:66
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Value stride(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
Value size(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
This class helps build Operations.
Definition Builders.h:209
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:436
result_range getResults()
Definition Operation.h:444
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:433
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isF64() const
Definition Types.cpp:41
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isF8E5M2() const
Definition Types.cpp:45
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition Types.cpp:78
bool isF8E4M3FN() const
Definition Types.cpp:44
bool isF32() const
Definition Types.cpp:40
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition Types.cpp:90
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
bool isF16() const
Definition Types.cpp:38
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
bool isBF16() const
Definition Types.cpp:37
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDes...
Definition Pattern.cpp:603
bool hasOcpFp8(const Chipset &chipset)
Definition Chipset.h:52
void populateCommonGPUTypeAndAttributeConversions(TypeConverter &typeConverter)
Remap common GPU memory spaces (Workgroup, Private, etc) to LLVM address spaces.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition Matchers.h:442
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void populateGpuMemorySpaceAttributeConversions(TypeConverter &typeConverter, const MemorySpaceMapping &mapping)
Populates memory space attribute conversion rules for lowering gpu.address_space to integer values.
void populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, amdgpu::Chipset chipset)
Note: This function will also add conversions for the AMDGPU-specific address spaces and types,...
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:136
void populateAMDGPUTypeAndAttributeConversions(TypeConverter &typeConverter)
Remap AMDGPU memory spaces to LLVM address spaces by mapping amdgpu::AddressSpace::fat_raw_buffer to ...
Returns the rocdl intrinsic corresponding to a SparseWMMA operation swmmac if one exists.
Represents the amdgpu gfx chipset version, e.g., gfx90a, gfx942, gfx1103.
Definition Chipset.h:22
unsigned majorVersion
Definition Chipset.h:23
unsigned minorVersion
Definition Chipset.h:24
static FailureOr< Chipset > parse(StringRef name)
Parses the chipset version string and returns the chipset on success, and failure otherwise.
Definition Chipset.cpp:14