AMDGPUToROCDL.cpp
1//===- AMDGPUToROCDL.cpp - AMDGPU to ROCDL dialect conversion -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"

#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"

29#include "llvm/ADT/STLExtras.h"
30#include "llvm/ADT/TypeSwitch.h"
31#include "llvm/Support/Casting.h"
32#include "llvm/Support/ErrorHandling.h"
33#include <optional>
34
35namespace mlir {
36#define GEN_PASS_DEF_CONVERTAMDGPUTOROCDLPASS
37#include "mlir/Conversion/Passes.h.inc"
38} // namespace mlir
39
40using namespace mlir;
41using namespace mlir::amdgpu;
42
43// Define commonly used chipset versions for convenience.
44constexpr Chipset kGfx908 = Chipset(9, 0, 8);
45constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
46constexpr Chipset kGfx942 = Chipset(9, 4, 2);
47constexpr Chipset kGfx950 = Chipset(9, 5, 0);
48constexpr Chipset kGfx1250 = Chipset(12, 5, 0);
49
50/// Convert an unsigned number `val` to i32.
51static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
52 Location loc, Value val) {
53 IntegerType i32 = rewriter.getI32Type();
54 // Force check that `val` is of int type.
55 auto valTy = cast<IntegerType>(val.getType());
56 if (i32 == valTy)
57 return val;
58 return valTy.getWidth() > 32
59 ? Value(LLVM::TruncOp::create(rewriter, loc, i32, val))
60 : Value(LLVM::ZExtOp::create(rewriter, loc, i32, val));
61}
62
63static Value createI32Constant(ConversionPatternRewriter &rewriter,
64 Location loc, int32_t value) {
65 return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), value);
66}
67
68/// Convert an unsigned number `val` to i64.
69static Value convertUnsignedToI64(ConversionPatternRewriter &rewriter,
70 Location loc, Value val) {
71 IntegerType i64 = rewriter.getI64Type();
72 // Force check that `val` is of int type.
73 auto valTy = cast<IntegerType>(val.getType());
74 if (i64 == valTy)
75 return val;
76 return valTy.getWidth() > 64
77 ? Value(LLVM::TruncOp::create(rewriter, loc, i64, val))
78 : Value(LLVM::ZExtOp::create(rewriter, loc, i64, val));
79}
80
81static Value createI64Constant(ConversionPatternRewriter &rewriter,
82 Location loc, int64_t value) {
83 return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), value);
84}
85
86/// Returns the linear index used to access an element in the memref.
87static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
88 Location loc, MemRefDescriptor &memRefDescriptor,
89                               ValueRange indices, ArrayRef<int64_t> strides) {
90  IntegerType i32 = rewriter.getI32Type();
91  Value index;
92  for (auto [i, increment, stride] : llvm::enumerate(indices, strides)) {
93 if (stride != 1) { // Skip if stride is 1.
94 Value strideValue =
95 ShapedType::isDynamic(stride)
96 ? convertUnsignedToI32(rewriter, loc,
97 memRefDescriptor.stride(rewriter, loc, i))
98 : LLVM::ConstantOp::create(rewriter, loc, i32, stride);
99 increment = LLVM::MulOp::create(rewriter, loc, increment, strideValue);
100 }
101 index = index ? LLVM::AddOp::create(rewriter, loc, index, increment)
102 : increment;
103 }
104 return index ? index : createI32Constant(rewriter, loc, 0);
105}
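// Illustrative example (a sketch, not from the upstream file): for a memref
// with static strides [8, 1] and indices (%i, %j), the loop above skips the
// unit stride and emits roughly
//   %t   = llvm.mul %i, %c8 : i32
//   %idx = llvm.add %t, %j : i32
// i.e. the row-major linearization i * 8 + j; dynamic strides are instead read
// out of the memref descriptor before the multiply.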
106
107/// Compute the contents of the `num_records` field for a given memref
108/// descriptor - that is, the byte offset one element past the greatest
109/// possible valid index into the memref.
110static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc,
111 MemRefType memrefType,
112 MemRefDescriptor &memrefDescriptor,
113 ArrayRef<int64_t> strides, int64_t elementByteWidth,
114 amdgpu::Chipset chipset, bool boundsCheck) {
115 if (chipset >= kGfx1250 && !boundsCheck) {
116 constexpr int64_t first45bits = (1ll << 45) - 1;
117 return createI64Constant(rewriter, loc, first45bits);
118 }
119 if (memrefType.hasStaticShape() &&
120 !llvm::any_of(strides, ShapedType::isDynamic)) {
121 int64_t size = memrefType.getRank() == 0 ? 1 : 0;
122 ArrayRef<int64_t> shape = memrefType.getShape();
123 for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
124 size = std::max(shape[i] * strides[i], size);
125 size = size * elementByteWidth;
126 return createI64Constant(rewriter, loc, size);
127 }
128 Value maxIndex;
129 for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
130 Value size = memrefDescriptor.size(rewriter, loc, i);
131 Value stride = memrefDescriptor.stride(rewriter, loc, i);
132 Value maxThisDim = LLVM::MulOp::create(rewriter, loc, size, stride);
133 maxIndex = maxIndex
134 ? LLVM::UMaxOp::create(rewriter, loc, maxIndex, maxThisDim)
135 : maxThisDim;
136 }
137 Value maxIndexI64 = convertUnsignedToI64(rewriter, loc, maxIndex);
138 Value byteWidthConst = createI64Constant(rewriter, loc, elementByteWidth);
139 return LLVM::MulOp::create(rewriter, loc, maxIndexI64, byteWidthConst);
140}
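// Illustrative example: for memref<16x4xf32> with static strides [4, 1], the
// static-shape branch above computes max(16 * 4, 4 * 1) = 64 elements, so
// num_records = 64 * 4 bytes = 256. On gfx1250 with bounds checks disabled,
// the field is instead set to 2^45 - 1 so the hardware check never fires.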
141
142static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc,
143 Value basePointer, Value numRecords,
144 bool boundsCheck, amdgpu::Chipset chipset,
145 Value cacheSwizzleStride = nullptr,
146 unsigned addressSpace = 8) {
147 // The stride value is generally 0. However, on MI-300 and onward, you can
148 // enable a cache swizzling mode by setting bit 14 of the stride field
149 // and setting that stride to a cache stride.
150 Type i16 = rewriter.getI16Type();
151 Value stride;
152 if (chipset.majorVersion == 9 && chipset >= kGfx942 && cacheSwizzleStride) {
153 Value cacheStrideZext =
154 LLVM::ZExtOp::create(rewriter, loc, i16, cacheSwizzleStride);
155 Value swizzleBit = LLVM::ConstantOp::create(
156 rewriter, loc, i16, rewriter.getI16IntegerAttr(1 << 14));
157 stride = LLVM::OrOp::create(rewriter, loc, cacheStrideZext, swizzleBit,
158 /*isDisjoint=*/true);
159 } else {
160 stride = LLVM::ConstantOp::create(rewriter, loc, i16,
161 rewriter.getI16IntegerAttr(0));
162 }
163
164 uint32_t flags = 0;
165 if (chipset >= kGfx1250) {
166 // Flag word:
167    //  bit 0: swizzle enable
168    //  bit 1: out-of-bounds check mode. 0 means OOB when
169    //         (total_offset + payload) > numRecords; 1 means OOB when
170    //         ((total_offset + payload) > numRecords) || ((offset + payload)
171    //         > stride). Only applies when swizzle_enable = 0; keep at zero,
172    //         so whether an access is OOB depends solely on numRecords.
173 // bits 2-3: Type (must be 0)
174 } else {
175 // Get the number of elements.
176 // Flag word:
177 // bits 0-11: dst sel, ignored by these intrinsics
178    // bits 12-14: num format (ignored, must be nonzero, 7=float)
179 // bits 15-18: data format (ignored, must be nonzero, 4=32bit)
180 // bit 19: In nested heap (0 here)
181 // bit 20: Behavior on unmap (0 means "return 0 / ignore")
182 // bits 21-22: Index stride for swizzles (N/A)
183 // bit 23: Add thread ID (0)
184 // bit 24: Reserved to 1 (RDNA) or 0 (CDNA)
185 // bits 25-26: Reserved (0)
186 // bit 27: Buffer is non-volatile (CDNA only)
187 // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 =
188 // none, 3 = either swizzles or testing against offset field) RDNA only
189 // bits 30-31: Type (must be 0)
190 flags |= (7 << 12) | (4 << 15);
191 if (chipset.majorVersion >= 10) {
192 flags |= (1 << 24);
193 uint32_t oob = boundsCheck ? 3 : 2;
194 flags |= (oob << 28);
195 }
196 }
197 Value flagsConst = createI32Constant(rewriter, loc, flags);
198 Type rsrcType =
199 LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
200 Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
201 loc, rsrcType, basePointer, stride, numRecords, flagsConst);
202 return resource;
203}
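// Illustrative flag values produced above: on CDNA (e.g. gfx90a) the flag word
// is (7 << 12) | (4 << 15) = 0x27000; on RDNA (majorVersion >= 10) with bounds
// checking enabled the extra (1 << 24) | (3 << 28) bits give 0x31027000, and
// with bounds checking disabled (oob = 2) the value is 0x21027000. On gfx1250
// the flag word stays 0.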
204
205namespace {
206struct FatRawBufferCastLowering
207 : public ConvertOpToLLVMPattern<FatRawBufferCastOp> {
208 FatRawBufferCastLowering(const LLVMTypeConverter &converter, Chipset chipset)
209 : ConvertOpToLLVMPattern<FatRawBufferCastOp>(converter),
210 chipset(chipset) {}
211
212 Chipset chipset;
213
214 LogicalResult
215 matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor,
216 ConversionPatternRewriter &rewriter) const override {
217 Location loc = op.getLoc();
218 Value memRef = adaptor.getSource();
219 Value unconvertedMemref = op.getSource();
220 MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());
221 MemRefDescriptor descriptor(memRef);
222
223 DataLayout dataLayout = DataLayout::closest(op);
224 int64_t elementByteWidth =
225 dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;
226
227 int64_t unusedOffset = 0;
228 SmallVector<int64_t, 5> strideVals;
229 if (failed(memrefType.getStridesAndOffset(strideVals, unusedOffset)))
230 return op.emitOpError("Can't lower non-stride-offset memrefs");
231
232 Value numRecords = adaptor.getValidBytes();
233 if (!numRecords)
234 numRecords =
235 getNumRecords(rewriter, loc, memrefType, descriptor, strideVals,
236 elementByteWidth, chipset, adaptor.getBoundsCheck());
237
238 Value basePointer =
239 adaptor.getResetOffset()
240 ? descriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
241 memrefType)
242 : descriptor.alignedPtr(rewriter, loc);
243
244 Value offset = adaptor.getResetOffset()
245 ? LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
246 rewriter.getIndexAttr(0))
247 : descriptor.offset(rewriter, loc);
248
249 bool hasSizes = memrefType.getRank() > 0;
250 // No need to unpack() and pack() all the individual sizes and strides,
251 // so we'll just extract the arrays.
252 Value sizes = hasSizes
253                      ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
254                                                     kSizePosInMemRefDescriptor)
255                      : Value{};
256    Value strides =
257        hasSizes ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
258                                                kStridePosInMemRefDescriptor)
259                 : Value{};
260
261 Value fatPtr = makeBufferRsrc(
262 rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(),
263 chipset, adaptor.getCacheSwizzleStride(), /*addressSpace=*/7);
264
265 Value result = MemRefDescriptor::poison(
266 rewriter, loc,
267 getTypeConverter()->convertType(op.getResult().getType()));
268 SmallVector<int64_t> pos{kAllocatedPtrPosInMemRefDescriptor};
269 result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr, pos);
270    result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr,
271                                         kAlignedPtrPosInMemRefDescriptor);
272    result = LLVM::InsertValueOp::create(rewriter, loc, result, offset,
273                                         kOffsetPosInMemRefDescriptor);
274    if (hasSizes) {
275      result = LLVM::InsertValueOp::create(rewriter, loc, result, sizes,
276                                           kSizePosInMemRefDescriptor);
277      result = LLVM::InsertValueOp::create(rewriter, loc, result, strides,
278                                           kStridePosInMemRefDescriptor);
279 }
280 rewriter.replaceOp(op, result);
281 return success();
282 }
283};
284
285/// Define lowering patterns for raw buffer ops
286template <typename GpuOp, typename Intrinsic>
287struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
288 RawBufferOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
289 : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}
290
291 Chipset chipset;
292 static constexpr uint32_t maxVectorOpWidth = 128;
293
294 LogicalResult
295 matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor,
296 ConversionPatternRewriter &rewriter) const override {
297 Location loc = gpuOp.getLoc();
298 Value memref = adaptor.getMemref();
299 Value unconvertedMemref = gpuOp.getMemref();
300 MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());
301
302 if (chipset.majorVersion < 9)
303 return gpuOp.emitOpError("raw buffer ops require GCN or higher");
304
305 Value storeData = adaptor.getODSOperands(0)[0];
306 if (storeData == memref) // no write component to this op
307 storeData = Value();
308 Type wantedDataType;
309 if (storeData)
310 wantedDataType = storeData.getType();
311 else
312 wantedDataType = gpuOp.getODSResults(0)[0].getType();
313
314 Value atomicCmpData = Value();
315 // Operand index 1 of a load is the indices, trying to read them can crash.
316 if (storeData) {
317 Value maybeCmpData = adaptor.getODSOperands(1)[0];
318 if (maybeCmpData != memref)
319 atomicCmpData = maybeCmpData;
320 }
321
322 Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
323
324 Type i32 = rewriter.getI32Type();
325
326 // Get the type size in bytes.
327 DataLayout dataLayout = DataLayout::closest(gpuOp);
328 int64_t elementByteWidth =
329 dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;
330 Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);
331
332    // If we want to load a vector<NxT> with total size <= 32
333    // bits, use a scalar load and bitcast it. Similarly, if bitsize(T) < 32
334    // and the total load size is >= 32, use a vector load of
335    // (N * bitsize(T)) / 32 x i32 and bitcast. Also, the CAS intrinsic
336    // requires integer operands, so bitcast any floats to integers.
337 Type llvmBufferValType = llvmWantedDataType;
338 if (atomicCmpData) {
339 if (auto floatType = dyn_cast<FloatType>(wantedDataType))
340 llvmBufferValType = this->getTypeConverter()->convertType(
341 rewriter.getIntegerType(floatType.getWidth()));
342 }
343 if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
344 uint32_t vecLen = dataVector.getNumElements();
345 uint32_t elemBits =
346 dataLayout.getTypeSizeInBits(dataVector.getElementType());
347 uint32_t totalBits = elemBits * vecLen;
348 bool usePackedFp16 =
349 isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
350 if (totalBits > maxVectorOpWidth)
351 return gpuOp.emitOpError(
352 "Total width of loads or stores must be no more than " +
353 Twine(maxVectorOpWidth) + " bits, but we call for " +
354 Twine(totalBits) +
355 " bits. This should've been caught in validation");
356 if (!usePackedFp16 && elemBits < 32) {
357 if (totalBits > 32) {
358 if (totalBits % 32 != 0)
359 return gpuOp.emitOpError("Load or store of more than 32-bits that "
360 "doesn't fit into words. Can't happen\n");
361 llvmBufferValType = this->typeConverter->convertType(
362 VectorType::get(totalBits / 32, i32));
363 } else {
364 llvmBufferValType = this->typeConverter->convertType(
365 rewriter.getIntegerType(totalBits));
366 }
367 }
368 }
369 if (auto vecType = dyn_cast<VectorType>(llvmBufferValType)) {
370      // Buffer intrinsics don't support 1-element vectors; cast them to
371 // scalars.
372 if (vecType.getNumElements() == 1)
373 llvmBufferValType = vecType.getElementType();
374 }
375
376 SmallVector<Value, 6> args;
377 if (storeData) {
378 if (llvmBufferValType != llvmWantedDataType) {
379 Value castForStore = LLVM::BitcastOp::create(
380 rewriter, loc, llvmBufferValType, storeData);
381 args.push_back(castForStore);
382 } else {
383 args.push_back(storeData);
384 }
385 }
386
387 if (atomicCmpData) {
388 if (llvmBufferValType != llvmWantedDataType) {
389 Value castForCmp = LLVM::BitcastOp::create(
390 rewriter, loc, llvmBufferValType, atomicCmpData);
391 args.push_back(castForCmp);
392 } else {
393 args.push_back(atomicCmpData);
394 }
395 }
396
397 // Construct buffer descriptor from memref, attributes
398 int64_t offset = 0;
399 SmallVector<int64_t, 5> strides;
400 if (failed(memrefType.getStridesAndOffset(strides, offset)))
401 return gpuOp.emitOpError("Can't lower non-stride-offset memrefs");
402
403 MemRefDescriptor memrefDescriptor(memref);
404
405 Value ptr = memrefDescriptor.bufferPtr(
406 rewriter, loc, *this->getTypeConverter(), memrefType);
407 Value numRecords =
408 getNumRecords(rewriter, loc, memrefType, memrefDescriptor, strides,
409 elementByteWidth, chipset, adaptor.getBoundsCheck());
410 Value resource = makeBufferRsrc(rewriter, loc, ptr, numRecords,
411 adaptor.getBoundsCheck(), chipset);
412 args.push_back(resource);
413
414 // Indexing (voffset)
415 Value voffset = getLinearIndexI32(rewriter, loc, memrefDescriptor,
416 adaptor.getIndices(), strides);
417 if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
418 indexOffset && *indexOffset > 0) {
419 Value extraOffsetConst = createI32Constant(rewriter, loc, *indexOffset);
420 voffset = voffset ? LLVM::AddOp::create(rewriter, loc, voffset,
421 extraOffsetConst)
422 : extraOffsetConst;
423 }
424 voffset = LLVM::MulOp::create(rewriter, loc, voffset, byteWidthConst);
425 args.push_back(voffset);
426
427 // SGPR offset.
428 Value sgprOffset = adaptor.getSgprOffset();
429 if (!sgprOffset)
430 sgprOffset = createI32Constant(rewriter, loc, 0);
431 sgprOffset = LLVM::MulOp::create(rewriter, loc, sgprOffset, byteWidthConst);
432 args.push_back(sgprOffset);
433
434 // bit 0: GLC = 0 (atomics drop value, less coherency)
435 // bits 1-2: SLC, DLC = 0 (similarly)
436 // bit 3: swizzled (0 for raw)
437 args.push_back(createI32Constant(rewriter, loc, 0));
438
439 llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
440 llvmBufferValType);
441 Operation *lowered = Intrinsic::create(rewriter, loc, resultTypes, args,
442 ArrayRef<NamedAttribute>());
443 if (lowered->getNumResults() == 1) {
444 Value replacement = lowered->getResult(0);
445 if (llvmBufferValType != llvmWantedDataType) {
446        replacement = LLVM::BitcastOp::create(rewriter, loc, llvmWantedDataType,
447                                              replacement);
448      }
449 rewriter.replaceOp(gpuOp, replacement);
450 } else {
451 rewriter.eraseOp(gpuOp);
452 }
453 return success();
454 }
455};
456
457// TODO: The AMDGPU backend already has all this bitpacking logic; we should move
458// it to some common place.
459/// Vmcnt, Expcnt and Lgkmcnt are decoded as follows:
460/// Vmcnt = Waitcnt[3:0] (pre-gfx9)
461/// Vmcnt = Waitcnt[15:14,3:0] (gfx9,10)
462/// Vmcnt = Waitcnt[15:10] (gfx11)
463/// Expcnt = Waitcnt[6:4] (pre-gfx11)
464/// Expcnt = Waitcnt[2:0] (gfx11)
465/// Lgkmcnt = Waitcnt[11:8] (pre-gfx10)
466/// Lgkmcnt = Waitcnt[13:8] (gfx10)
467/// Lgkmcnt = Waitcnt[9:4] (gfx11)
468static FailureOr<unsigned> encodeWaitcnt(Chipset chipset, unsigned vmcnt,
469 unsigned expcnt, unsigned lgkmcnt) {
470 if (chipset.majorVersion < 9) {
471 vmcnt = std::min(15u, vmcnt);
472 expcnt = std::min(7u, expcnt);
473 lgkmcnt = std::min(15u, lgkmcnt);
474 return vmcnt | (expcnt << 4) | (lgkmcnt << 8);
475 }
476 if (chipset.majorVersion == 9) {
477 vmcnt = std::min(63u, vmcnt);
478 expcnt = std::min(7u, expcnt);
479 lgkmcnt = std::min(15u, lgkmcnt);
480 unsigned lowBits = vmcnt & 0xF;
481 unsigned highBits = (vmcnt >> 4) << 14;
482 unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
483 return lowBits | highBits | otherCnts;
484 }
485 if (chipset.majorVersion == 10) {
486 vmcnt = std::min(63u, vmcnt);
487 expcnt = std::min(7u, expcnt);
488 lgkmcnt = std::min(63u, lgkmcnt);
489 unsigned lowBits = vmcnt & 0xF;
490 unsigned highBits = (vmcnt >> 4) << 14;
491 unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
492 return lowBits | highBits | otherCnts;
493 }
494 if (chipset.majorVersion == 11) {
495 vmcnt = std::min(63u, vmcnt);
496 expcnt = std::min(7u, expcnt);
497 lgkmcnt = std::min(63u, lgkmcnt);
498 return (vmcnt << 10) | expcnt | (lgkmcnt << 4);
499 }
500 return failure();
501}
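// Worked example: vmcnt=5, expcnt=1, lgkmcnt=2 encodes on gfx9 as
// (5 & 0xF) | (1 << 4) | (2 << 8) = 0x215 (the upper vmcnt bits at bit 14 are
// zero here), while the same counters on gfx11 encode as
// (5 << 10) | 1 | (2 << 4) = 0x1421.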
502
503struct MemoryCounterWaitOpLowering
504 : public ConvertOpToLLVMPattern<MemoryCounterWaitOp> {
505 MemoryCounterWaitOpLowering(const LLVMTypeConverter &converter,
506 Chipset chipset)
507 : ConvertOpToLLVMPattern<MemoryCounterWaitOp>(converter),
508 chipset(chipset) {}
509
510 Chipset chipset;
511
512 LogicalResult
513 matchAndRewrite(MemoryCounterWaitOp op, OpAdaptor adaptor,
514 ConversionPatternRewriter &rewriter) const override {
515 if (chipset.majorVersion >= 12) {
516 Location loc = op.getLoc();
517 if (std::optional<int> ds = adaptor.getDs())
518 ROCDL::WaitDscntOp::create(rewriter, loc, *ds);
519
520 if (std::optional<int> load = adaptor.getLoad())
521 ROCDL::WaitLoadcntOp::create(rewriter, loc, *load);
522
523 if (std::optional<int> store = adaptor.getStore())
524 ROCDL::WaitStorecntOp::create(rewriter, loc, *store);
525
526 if (std::optional<int> exp = adaptor.getExp())
527 ROCDL::WaitExpcntOp::create(rewriter, loc, *exp);
528
529 if (std::optional<int> tensor = adaptor.getTensor())
530 ROCDL::WaitTensorcntOp::create(rewriter, loc, *tensor);
531
532 rewriter.eraseOp(op);
533 return success();
534 }
535
536 if (adaptor.getTensor())
537 return op.emitOpError("unsupported chipset");
538
539 auto getVal = [](Attribute attr) -> unsigned {
540 if (attr)
541 return cast<IntegerAttr>(attr).getInt();
542
543 // This value will be clamped to the maximum value for the chipset.
544 return 1024;
545 };
546 unsigned ds = getVal(adaptor.getDsAttr());
547 unsigned exp = getVal(adaptor.getExpAttr());
548
549 unsigned vmcnt = 1024;
550 Attribute load = adaptor.getLoadAttr();
551 Attribute store = adaptor.getStoreAttr();
552 if (load && store) {
553 vmcnt = getVal(load) + getVal(store);
554 } else if (load) {
555 vmcnt = getVal(load);
556 } else if (store) {
557 vmcnt = getVal(store);
558 }
559
560 FailureOr<unsigned> waitcnt = encodeWaitcnt(chipset, vmcnt, exp, ds);
561 if (failed(waitcnt))
562 return op.emitOpError("unsupported chipset");
563
564 rewriter.replaceOpWithNewOp<ROCDL::SWaitcntOp>(op, *waitcnt);
565 return success();
566 }
567};
568
569struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
570 LDSBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
571 : ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {}
572
573 Chipset chipset;
574
575 LogicalResult
576 matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
577 ConversionPatternRewriter &rewriter) const override {
578 Location loc = op.getLoc();
579 // This ensures that waits on global memory aren't introduced on
580 // chips that don't have the BackOffBarrier feature enabled in LLVM.
581 bool requiresInlineAsm = chipset < kGfx90a;
582
583 Attribute mmra =
584 rewriter.getAttr<LLVM::MMRATagAttr>("amdgpu-synchronize-as", "local");
585 // Note: while there *is* a workgroup-one-as scope, this, when combined with
586 // the MMRA, will lead to the fence having no effect. This is because the
587 // codepaths for an atomic load or store will observe that a
588 // one-address-space atomic to LDS requires no synchronization because
589 // operations on LDS are totally ordered with respect to each other, and so
590 // will not emit the correct waitcnt operations that these fences are
591 // intended to produce. Therefore, we use a broader type of fence and rely
592 // on the MMRA to relax it to the semantics we want.
593 StringRef scope = "workgroup";
594
595 auto relFence = LLVM::FenceOp::create(rewriter, loc,
596 LLVM::AtomicOrdering::release, scope);
597 relFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra);
598 if (requiresInlineAsm) {
599 auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
600 LLVM::AsmDialect::AD_ATT);
601 const char *asmStr = ";;;WARNING: BREAKS DEBUG WATCHES\ns_barrier";
602 const char *constraints = "";
603 LLVM::InlineAsmOp::create(
604 rewriter, loc,
605 /*resultTypes=*/TypeRange(), /*operands=*/ValueRange(),
606 /*asm_string=*/asmStr, constraints, /*has_side_effects=*/true,
607 /*is_align_stack=*/false, LLVM::TailCallKind::None,
608 /*asm_dialect=*/asmDialectAttr,
609 /*operand_attrs=*/ArrayAttr());
610 } else if (chipset.majorVersion < 12) {
611 ROCDL::SBarrierOp::create(rewriter, loc);
612 } else {
613 ROCDL::BarrierSignalOp::create(rewriter, loc, -1);
614 ROCDL::BarrierWaitOp::create(rewriter, loc, -1);
615 }
616
617 auto acqFence = LLVM::FenceOp::create(rewriter, loc,
618 LLVM::AtomicOrdering::acquire, scope);
619 acqFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra);
620 rewriter.replaceOp(op, acqFence);
621 return success();
622 }
623};
624
625struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
626 SchedBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
627 : ConvertOpToLLVMPattern<SchedBarrierOp>(converter), chipset(chipset) {}
628
629 Chipset chipset;
630
631 LogicalResult
632 matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor,
633 ConversionPatternRewriter &rewriter) const override {
634 rewriter.replaceOpWithNewOp<ROCDL::SchedBarrier>(op,
635 (uint32_t)op.getOpts());
636 return success();
637 }
638};
639
640} // namespace
641
642/// Pack small float vector operands (fp4/fp6/fp8/bf16) into the format
643/// expected by scaled matrix multiply intrinsics (MFMA/WMMA).
644///
645/// Specifically:
646/// 1. If the element type is bfloat16, bitcast it to i16, unless the rocdl
647///    intrinsic allows bf16 (newer MFMAs accept bf16 operands; see
648///    IntrinsicsAMDGPU.td for reference).
649/// 2. If instead we have a more than 64-bit quantity, use a <N / 4 x i32>
650/// instead, which is what the f8f6f4 intrinsics use.
651/// 3. If `input` is a vector of N <= 8 bytes, bitcast it to a (N * 8)-bit
652/// integer.
653///
654/// Note that the type of `input` has already been LLVM type converted:
655/// therefore 8-bit and smaller floats are represented as their corresponding
656/// `iN` integers.
657static Value packSmallFloatVectorOperand(ConversionPatternRewriter &rewriter,
658 Location loc, Value input,
659 bool allowBf16 = true) {
660 Type inputType = input.getType();
661 if (auto vectorType = dyn_cast<VectorType>(inputType)) {
662 if (vectorType.getElementType().isBF16() && !allowBf16)
663 return LLVM::BitcastOp::create(
664 rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
665 if (vectorType.getElementType().isInteger(8) &&
666 vectorType.getNumElements() <= 8)
667 return LLVM::BitcastOp::create(
668 rewriter, loc,
669 rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
670 if (isa<IntegerType>(vectorType.getElementType()) &&
671 vectorType.getElementTypeBitWidth() <= 8) {
672 int64_t numWords = llvm::divideCeil(
673 vectorType.getNumElements() * vectorType.getElementTypeBitWidth(),
674 32);
675 return LLVM::BitcastOp::create(
676 rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()),
677 input);
678 }
679 }
680 return input;
681}
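// Illustrative example: a vector<8xf8E4M3FN> operand arrives here already
// type-converted to vector<8xi8> and is bitcast to i64 (8 x 8 bits); a longer
// vector<32xi8> is bitcast to vector<8xi32>. A vector<4xbf16> operand is left
// untouched when allowBf16 is true and bitcast to vector<4xi16> otherwise.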
682
683/// Converts sparse MFMA (smfmac) operands to the expected ROCDL types.
684static Value convertSparseMFMAVectorOperand(ConversionPatternRewriter &rewriter,
685 Location loc, Value input,
686 bool allowBf16 = true) {
687 Type inputType = input.getType();
688 auto vectorType = cast<VectorType>(inputType);
689 // bf16 -> i16 when not allowed (pre-gfx950).
690 if (vectorType.getElementType().isBF16() && !allowBf16)
691 return LLVM::BitcastOp::create(
692 rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
693 // i8/fp8 vectors -> vector<Nxi32>.
694 if (isa<IntegerType>(vectorType.getElementType()) &&
695 vectorType.getElementTypeBitWidth() <= 8) {
696 int64_t numWords = llvm::divideCeil(
697 vectorType.getNumElements() * vectorType.getElementTypeBitWidth(), 32);
698 return LLVM::BitcastOp::create(
699 rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()), input);
700 }
701 return input;
702}
703
704/// Converts the scaled MFMA/WMMA operands, `scalesA` and `scalesB`, from MLIR
705/// AMDGPU dialect convention to ROCDL and LLVM AMDGPU intrinsics convention.
706///
707/// Specifically:
708/// 1. If `input` is an i8 value, zero extend it to i32
709/// 2. If `input` is a vector of length 4 or 8 and type i8, cast it to i32
710///
711/// Note that the type of `input` has already been LLVM type converted:
712/// therefore 8-bit and smaller floats are represented as their corresponding
713/// `iN` integers.
714static Value castScaleOperand(ConversionPatternRewriter &rewriter, Location loc,
715 Value input) {
716 return TypeSwitch<Type, Value>(input.getType())
717 .Case([&](IntegerType) {
718 // Handle scalar i8: zero extend to i32.
719 return LLVM::ZExtOp::create(rewriter, loc, rewriter.getI32Type(),
720 input);
721 })
722 .Case([&](VectorType vectorType) {
723 // Handle vector<4xi8> -> i32 or vector<8xi8> -> i64.
724 int64_t numElements = vectorType.getNumElements();
725 assert((numElements == 4 || numElements == 8) &&
726 "scale operand must be a vector of length 4 or 8");
727 IntegerType outputType =
728 (numElements == 4) ? rewriter.getI32Type() : rewriter.getI64Type();
729 return LLVM::BitcastOp::create(rewriter, loc, outputType, input);
730 })
731 .DefaultUnreachable("unexpected input type for scale operand");
732}
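// Illustrative example: a scalar i8 scale is zero-extended to i32, a
// vector<4xi8> of scales is bitcast to i32, and a vector<8xi8> is bitcast to
// i64, matching the operand shapes the scaled intrinsics expect.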
733
734/// Maps f8 scale element types to WMMA scale format codes.
735static std::optional<uint32_t> getWmmaScaleFormat(Type elemType) {
736  return TypeSwitch<Type, std::optional<uint32_t>>(elemType)
737      .Case([](Float8E8M0FNUType) { return 0; })
738 .Case([](Float8E4M3FNType) { return 2; })
739 .Default(std::nullopt);
740}
741
742/// Determines the ROCDL intrinsic name for scaled WMMA based on dimensions
743/// and scale block size (16 or 32).
744static std::optional<StringRef>
745scaledWmmaOpToIntrinsic(uint32_t m, uint32_t n, uint32_t k, bool isScale16) {
746  if (m == 16 && n == 16 && k == 128)
747 return isScale16
748 ? ROCDL::wmma_scale16_f32_16x16x128_f8f6f4::getOperationName()
749 : ROCDL::wmma_scale_f32_16x16x128_f8f6f4::getOperationName();
750
751 if (m == 32 && n == 16 && k == 128)
752 return isScale16 ? ROCDL::wmma_scale16_f32_32x16x128_f4::getOperationName()
753 : ROCDL::wmma_scale_f32_32x16x128_f4::getOperationName();
754
755 return std::nullopt;
756}
757
758/// Push an input operand. If it is a float type, there is nothing to do. If it
759/// is an integer type, then we need to also push its signedness (1 for signed, 0
760/// for unsigned) and we need to pack the input 16xi8 vector into a 4xi32
761/// vector (or the 8xi8 vector into a 2xi32 one for gfx12+).
762/// We also need to convert bfloat inputs to i16 to account for the bfloat
763/// intrinsics having been defined before the AMD backend supported bfloat. We
764/// similarly need to pack 8-bit float types into integers as if they were i8
765/// (which they are for the backend's purposes).
766static void wmmaPushInputOperand(
767    ConversionPatternRewriter &rewriter, Location loc,
768 const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput,
769 Value mlirInput, SmallVectorImpl<Value> &operands,
770 SmallVectorImpl<NamedAttribute> &attrs, StringRef attrName) {
771 Type inputType = llvmInput.getType();
772 auto vectorType = dyn_cast<VectorType>(inputType);
773 if (!vectorType) {
774 operands.push_back(llvmInput);
775 return;
776 }
777 Type elemType = vectorType.getElementType();
778 if (elemType.getIntOrFloatBitWidth() > 8) {
779 operands.push_back(llvmInput);
780 return;
781 }
782
783 // We need to check the type of the input before conversion to properly test
784 // for int8. This is because, in LLVM, fp8 type is converted to int8, so the
785 // fp8/int8 information is lost during the conversion process.
786 auto mlirInputType = cast<VectorType>(mlirInput.getType());
787 bool isInputInteger = mlirInputType.getElementType().isInteger();
788 if (isInputInteger) {
789 // if element type is 8-bit signed or unsigned, ignore the isUnsigned flag
790 bool localIsUnsigned = isUnsigned;
791 if (elemType.isUnsignedInteger()) {
792 localIsUnsigned = true;
793 } else if (elemType.isSignedInteger()) {
794 localIsUnsigned = false;
795 }
796 attrs.push_back(
797 NamedAttribute(attrName, rewriter.getBoolAttr(!localIsUnsigned)));
798 }
799
800 int64_t numBits =
801 vectorType.getNumElements() * elemType.getIntOrFloatBitWidth();
802 Type i32 = rewriter.getI32Type();
803 Type intrinsicInType = numBits <= 32
804 ? (Type)rewriter.getIntegerType(numBits)
805 : (Type)VectorType::get(numBits / 32, i32);
806 auto llvmIntrinsicInType = typeConverter->convertType(intrinsicInType);
807 Value castInput = rewriter.createOrFold<LLVM::BitcastOp>(
808 loc, llvmIntrinsicInType, llvmInput);
809 // The wave64-mode 16x16x16 intrinsics that take 4-bit integers only need
810 // (256 / 64) * 4 = 16 bits of input (on gfx12+) but take i32 arguments.
811 // Add in the zeros here.
812 if (numBits < 32)
813 castInput = LLVM::ZExtOp::create(rewriter, loc, i32, castInput);
814 operands.push_back(castInput);
815}
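// Illustrative example: an int8 input of type vector<16xi8> pushes a
// signedness attribute and is bitcast to vector<4xi32> (128 bits); the gfx12+
// vector<8xi8> form becomes vector<2xi32>. A 16-bit quantity such as
// vector<4xi4> is bitcast to i16 and then zero-extended to i32, since the
// intrinsics take full i32 arguments.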
816
817/// Push the output operand. In many cases this only pushes the output onto the
818/// operand list. But when we have f16 -> f16 or bf16 -> bf16 intrinsics, since
819/// the same number of VGPRs is used, we need to decide whether to store the
820/// result in the upper or the lower 16 bits of the VGPRs. To store the result
821/// in the lower 16 bits, set subwordOffset to 1; otherwise the result will be
822/// stored in the upper part. The subwordOffset must not be set for gfx12, as
823/// the instructions have been changed to return fewer registers instead.
824static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
825 Location loc,
826 const TypeConverter *typeConverter,
827 Value output, int32_t subwordOffset,
828 bool clamp, SmallVectorImpl<Value> &operands,
829                                  SmallVectorImpl<NamedAttribute> &attrs) {
830  Type inputType = output.getType();
831 auto vectorType = dyn_cast<VectorType>(inputType);
832 Type elemType = vectorType.getElementType();
833 operands.push_back(output);
834 if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) {
835 attrs.push_back(
836 NamedAttribute("opsel", rewriter.getBoolAttr(subwordOffset)));
837 } else if (elemType.isInteger(32)) {
838 attrs.push_back(NamedAttribute("clamp", rewriter.getBoolAttr(clamp)));
839 }
840}
841
842/// Return true if `type` is the E5M2 variant of an 8-bit float that is
843/// supported by the `_bf8` instructions on the given `chipset`.
844static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type) {
845 return (chipset == kGfx942 && isa<Float8E5M2FNUZType>(type)) ||
846 (hasOcpFp8(chipset) && isa<Float8E5M2Type>(type));
847}
848
849/// Return true if `type` is the E4M3FN variant of an 8-bit float that is
850/// supported by the `_fp8` instructions on the given `chipset`.
851static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type) {
852 return (chipset == kGfx942 && isa<Float8E4M3FNUZType>(type)) ||
853 (hasOcpFp8(chipset) && isa<Float8E4M3FNType>(type));
854}
855
856/// Return the `rocdl` intrinsic corresponding to a MFMA operation `mfma`
857/// if one exists. This includes checking to ensure the intrinsic is supported
858/// on the architecture you are compiling for.
859static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
860 Chipset chipset) {
861 uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
862 b = mfma.getBlocks();
863 Type sourceElem = getElementTypeOrSelf(mfma.getSourceA().getType());
864 Type destElem = getElementTypeOrSelf(mfma.getDestC().getType());
865
866 if (sourceElem.isF32() && destElem.isF32()) {
867 if (mfma.getReducePrecision() && chipset >= kGfx942) {
868 if (m == 32 && n == 32 && k == 4 && b == 1)
869 return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
870 if (m == 16 && n == 16 && k == 8 && b == 1)
871 return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
872 }
873 if (m == 32 && n == 32 && k == 1 && b == 2)
874 return ROCDL::mfma_f32_32x32x1f32::getOperationName();
875 if (m == 16 && n == 16 && k == 1 && b == 4)
876 return ROCDL::mfma_f32_16x16x1f32::getOperationName();
877 if (m == 4 && n == 4 && k == 1 && b == 16)
878 return ROCDL::mfma_f32_4x4x1f32::getOperationName();
879 if (m == 32 && n == 32 && k == 2 && b == 1)
880 return ROCDL::mfma_f32_32x32x2f32::getOperationName();
881 if (m == 16 && n == 16 && k == 4 && b == 1)
882 return ROCDL::mfma_f32_16x16x4f32::getOperationName();
883 }
884
885 if (sourceElem.isF16() && destElem.isF32()) {
886 if (chipset >= kGfx950) {
887 if (m == 32 && n == 32 && k == 16 && b == 1)
888 return ROCDL::mfma_f32_32x32x16_f16::getOperationName();
889 if (m == 16 && n == 16 && k == 32 && b == 1)
890 return ROCDL::mfma_f32_16x16x32_f16::getOperationName();
891 }
892 if (m == 32 && n == 32 && k == 4 && b == 2)
893 return ROCDL::mfma_f32_32x32x4f16::getOperationName();
894 if (m == 16 && n == 16 && k == 4 && b == 4)
895 return ROCDL::mfma_f32_16x16x4f16::getOperationName();
896 if (m == 4 && n == 4 && k == 4 && b == 16)
897 return ROCDL::mfma_f32_4x4x4f16::getOperationName();
898 if (m == 32 && n == 32 && k == 8 && b == 1)
899 return ROCDL::mfma_f32_32x32x8f16::getOperationName();
900 if (m == 16 && n == 16 && k == 16 && b == 1)
901 return ROCDL::mfma_f32_16x16x16f16::getOperationName();
902 }
903
904 if (sourceElem.isBF16() && destElem.isF32()) {
905 if (chipset >= kGfx950) {
906 if (m == 32 && n == 32 && k == 16 && b == 1)
907 return ROCDL::mfma_f32_32x32x16_bf16::getOperationName();
908 if (m == 16 && n == 16 && k == 32 && b == 1)
909 return ROCDL::mfma_f32_16x16x32_bf16::getOperationName();
910 }
911 if (chipset >= kGfx90a) {
912 if (m == 32 && n == 32 && k == 4 && b == 2)
913 return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
914 if (m == 16 && n == 16 && k == 4 && b == 4)
915 return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
916 if (m == 4 && n == 4 && k == 4 && b == 16)
917 return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
918 if (m == 32 && n == 32 && k == 8 && b == 1)
919 return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
920 if (m == 16 && n == 16 && k == 16 && b == 1)
921 return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
922 }
923 if (m == 32 && n == 32 && k == 2 && b == 2)
924 return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
925 if (m == 16 && n == 16 && k == 2 && b == 4)
926 return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
927 if (m == 4 && n == 4 && k == 2 && b == 16)
928 return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
929 if (m == 32 && n == 32 && k == 4 && b == 1)
930 return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
931 if (m == 16 && n == 16 && k == 8 && b == 1)
932 return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
933 }
934
935 if (sourceElem.isInteger(8) && destElem.isInteger(32)) {
936 if (chipset >= kGfx950) {
937 if (m == 32 && n == 32 && k == 32 && b == 1)
938 return ROCDL::mfma_i32_32x32x32_i8::getOperationName();
939 if (m == 16 && n == 16 && k == 64 && b == 1)
940 return ROCDL::mfma_i32_16x16x64_i8::getOperationName();
941 }
942 if (m == 32 && n == 32 && k == 4 && b == 2)
943 return ROCDL::mfma_i32_32x32x4i8::getOperationName();
944 if (m == 16 && n == 16 && k == 4 && b == 4)
945 return ROCDL::mfma_i32_16x16x4i8::getOperationName();
946 if (m == 4 && n == 4 && k == 4 && b == 16)
947 return ROCDL::mfma_i32_4x4x4i8::getOperationName();
948 if (m == 32 && n == 32 && k == 8 && b == 1)
949 return ROCDL::mfma_i32_32x32x8i8::getOperationName();
950 if (m == 16 && n == 16 && k == 16 && b == 1)
951 return ROCDL::mfma_i32_16x16x16i8::getOperationName();
952 if (m == 32 && n == 32 && k == 16 && b == 1 && chipset >= kGfx942)
953 return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
954 if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >= kGfx942)
955 return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
956 }
957
958 if (sourceElem.isF64() && destElem.isF64() && chipset >= kGfx90a) {
959 if (m == 16 && n == 16 && k == 4 && b == 1)
960 return ROCDL::mfma_f64_16x16x4f64::getOperationName();
961 if (m == 4 && n == 4 && k == 4 && b == 4)
962 return ROCDL::mfma_f64_4x4x4f64::getOperationName();
963 }
964
965 if (destElem.isF32() && typeIsExpectedBf8ForChipset(chipset, sourceElem)) {
966 // Known to be correct because there are no scalar f8 instructions and
967 // because a length mismatch will have been caught by the verifier.
968 Type sourceBElem =
969 cast<VectorType>(mfma.getSourceB().getType()).getElementType();
970 if (m == 16 && n == 16 && k == 32 && b == 1) {
971 if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
972 return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
973 if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
974 return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
975 }
976 if (m == 32 && n == 32 && k == 16 && b == 1) {
977 if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
978 return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
979 if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
980 return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
981 }
982 }
983
984 if (destElem.isF32() && typeIsExpectedFp8ForChipset(chipset, sourceElem)) {
985 Type sourceBElem =
986 cast<VectorType>(mfma.getSourceB().getType()).getElementType();
987 if (m == 16 && n == 16 && k == 32 && b == 1) {
988 if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
989 return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
990 if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
991 return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
992 }
993 if (m == 32 && n == 32 && k == 16 && b == 1) {
994 if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
995 return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
996 if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
997 return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
998 }
999 }
1000
1001 return std::nullopt;
1002}
1003
1004static std::optional<uint32_t> smallFloatTypeToFormatCode(Type mlirElemType) {
1005  return TypeSwitch<Type, std::optional<uint32_t>>(mlirElemType)
1006      .Case([](Float8E4M3FNType) { return 0u; })
1007 .Case([](Float8E5M2Type) { return 1u; })
1008 .Case([](Float6E2M3FNType) { return 2u; })
1009 .Case([](Float6E3M2FNType) { return 3u; })
1010 .Case([](Float4E2M1FNType) { return 4u; })
1011 .Default(std::nullopt);
1012}
1013
1014/// If there is a scaled MFMA instruction for the input element types `aType`
1015/// and `bType`, output type `destType`, problem size M, N, K, and B (number of
1016/// blocks) on the given `chipset`, return a tuple consisting of the
1017/// OperationName of the intrinsic and the type codes that need to be passed to
1018/// that intrinsic. Note that this is also used to implement some un-scaled
1019/// MFMAs, since the compiler represents the ordinary instruction as a "scaled"
1020/// MFMA with a scale of 0.
1021static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1022mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m,
1023 uint32_t n, uint32_t k, uint32_t b, Chipset chipset) {
1024 aType = getElementTypeOrSelf(aType);
1025 bType = getElementTypeOrSelf(bType);
1026 destType = getElementTypeOrSelf(destType);
1027
1028 if (chipset < kGfx950)
1029 return std::nullopt;
1030 if (!isa<Float32Type>(destType))
1031 return std::nullopt;
1032
1033 std::optional<uint32_t> aTypeCode = smallFloatTypeToFormatCode(aType);
1034 std::optional<uint32_t> bTypeCode = smallFloatTypeToFormatCode(bType);
1035 if (!aTypeCode || !bTypeCode)
1036 return std::nullopt;
1037
1038 if (m == 32 && n == 32 && k == 64 && b == 1)
1039 return std::tuple{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(),
1040 *aTypeCode, *bTypeCode};
1041 if (m == 16 && n == 16 && k == 128 && b == 1)
1042 return std::tuple{
1043 ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), *aTypeCode,
1044 *bTypeCode};
1045
1046 return std::nullopt;
1047}
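// Illustrative example: aType = f8E4M3FN (code 0) and bType = f4E2M1FN
// (code 4) with M = N = 32, K = 64, blocks = 1 on gfx950+ map to
// ROCDL::mfma_scale_f32_32x32x64_f8f6f4 together with the type codes (0, 4),
// which the MFMA lowering later passes as the cbsz and blgp immediates.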
1048
1049static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1050mfmaOpToScaledIntrinsic(MFMAOp mfma, Chipset chipset) {
1051  return mfmaOpToScaledIntrinsic(
1052      mfma.getSourceA().getType(), mfma.getSourceB().getType(),
1053 mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(),
1054 mfma.getBlocks(), chipset);
1055}
1056
1057static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1058mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
1059 return mfmaOpToScaledIntrinsic(smfma.getSourceA().getType(),
1060 smfma.getSourceB().getType(),
1061 smfma.getDestC().getType(), smfma.getM(),
1062 smfma.getN(), smfma.getK(), 1u, chipset);
1063}
1064
1065/// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
1066/// for RDNA3/4 architectures.
1067static std::optional<StringRef>
1068wmmaOpToIntrinsicRDNA(Type elemSourceType, Type elemBSourceType,
1069 Type elemDestType, uint32_t k, bool isRDNA3) {
1070 using fp8 = Float8E4M3FNType;
1071 using bf8 = Float8E5M2Type;
1072
1073 // Handle k == 16 for RDNA3/4.
1074 if (k == 16) {
1075 // Common patterns for RDNA3 and RDNA4.
1076 if (elemSourceType.isF16() && elemDestType.isF32())
1077 return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
1078 if (elemSourceType.isBF16() && elemDestType.isF32())
1079 return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
1080 if (elemSourceType.isF16() && elemDestType.isF16())
1081 return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
1082 if (elemSourceType.isBF16() && elemDestType.isBF16())
1083 return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
1084 if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
1085 return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
1086
1087 // RDNA3 specific patterns.
1088 if (isRDNA3) {
1089 if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
1090 return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
1091 return std::nullopt;
1092 }
1093
1094 // RDNA4 specific patterns (fp8/bf8).
1095 if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
1096 elemDestType.isF32())
1097 return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
1098 if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
1099 elemDestType.isF32())
1100 return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
1101 if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
1102 elemDestType.isF32())
1103 return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
1104 if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
1105 elemDestType.isF32())
1106 return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
1107 if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
1108 return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
1109
1110 return std::nullopt;
1111 }
1112
1113 // Handle k == 32 for RDNA4.
1114 if (k == 32 && !isRDNA3) {
1115 if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
1116 return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
1117 }
1118
1119 return std::nullopt;
1120}
1121
1122/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
1123/// for the gfx1250 architecture.
1124static std::optional<StringRef> wmmaOpToIntrinsicGfx1250(Type elemSourceType,
1125 Type elemBSourceType,
1126 Type elemDestType,
1127 uint32_t k) {
1128 using fp8 = Float8E4M3FNType;
1129 using bf8 = Float8E5M2Type;
1130
1131 if (k == 4) {
1132 if (elemSourceType.isF32() && elemDestType.isF32())
1133 return ROCDL::wmma_f32_16x16x4_f32::getOperationName();
1134
1135 return std::nullopt;
1136 }
1137
1138 if (k == 32) {
1139 if (elemSourceType.isF16() && elemDestType.isF32())
1140 return ROCDL::wmma_f32_16x16x32_f16::getOperationName();
1141 if (elemSourceType.isBF16() && elemDestType.isF32())
1142 return ROCDL::wmma_f32_16x16x32_bf16::getOperationName();
1143 if (elemSourceType.isF16() && elemDestType.isF16())
1144 return ROCDL::wmma_f16_16x16x32_f16::getOperationName();
1145 if (elemSourceType.isBF16() && elemDestType.isBF16())
1146 return ROCDL::wmma_bf16_16x16x32_bf16::getOperationName();
1147
1148 return std::nullopt;
1149 }
1150
1151 if (k == 64) {
1152 if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1153 if (elemDestType.isF32())
1154 return ROCDL::wmma_f32_16x16x64_fp8_fp8::getOperationName();
1155 if (elemDestType.isF16())
1156 return ROCDL::wmma_f16_16x16x64_fp8_fp8::getOperationName();
1157 }
1158 if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1159 if (elemDestType.isF32())
1160 return ROCDL::wmma_f32_16x16x64_fp8_bf8::getOperationName();
1161 if (elemDestType.isF16())
1162 return ROCDL::wmma_f16_16x16x64_fp8_bf8::getOperationName();
1163 }
1164 if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1165 if (elemDestType.isF32())
1166 return ROCDL::wmma_f32_16x16x64_bf8_bf8::getOperationName();
1167 if (elemDestType.isF16())
1168 return ROCDL::wmma_f16_16x16x64_bf8_bf8::getOperationName();
1169 }
1170 if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1171 if (elemDestType.isF32())
1172 return ROCDL::wmma_f32_16x16x64_bf8_fp8::getOperationName();
1173 if (elemDestType.isF16())
1174 return ROCDL::wmma_f16_16x16x64_bf8_fp8::getOperationName();
1175 }
1176 if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
1177 return ROCDL::wmma_i32_16x16x64_iu8::getOperationName();
1178
1179 return std::nullopt;
1180 }
1181
1182 if (k == 128) {
1183 if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1184 if (elemDestType.isF32())
1185 return ROCDL::wmma_f32_16x16x128_fp8_fp8::getOperationName();
1186 if (elemDestType.isF16())
1187 return ROCDL::wmma_f16_16x16x128_fp8_fp8::getOperationName();
1188 }
1189 if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1190 if (elemDestType.isF32())
1191 return ROCDL::wmma_f32_16x16x128_fp8_bf8::getOperationName();
1192 if (elemDestType.isF16())
1193 return ROCDL::wmma_f16_16x16x128_fp8_bf8::getOperationName();
1194 }
1195 if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1196 if (elemDestType.isF32())
1197 return ROCDL::wmma_f32_16x16x128_bf8_bf8::getOperationName();
1198 if (elemDestType.isF16())
1199 return ROCDL::wmma_f16_16x16x128_bf8_bf8::getOperationName();
1200 }
1201 if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1202 if (elemDestType.isF32())
1203 return ROCDL::wmma_f32_16x16x128_bf8_fp8::getOperationName();
1204 if (elemDestType.isF16())
1205 return ROCDL::wmma_f16_16x16x128_bf8_fp8::getOperationName();
1206 }
1207
1208 return std::nullopt;
1209 }
1210
1211 return std::nullopt;
1212}
1213
1214/// Returns the `rocdl` intrinsic corresponding to a SparseMFMA (smfmac)
1215/// operation if one exists. This includes checking to ensure the intrinsic is
1216/// supported on the architecture you are compiling for.
1217static std::optional<StringRef> smfmacOpToIntrinsic(SparseMFMAOp op,
1218 Chipset chipset) {
1219 bool isGfx950 = chipset >= kGfx950;
1220 auto isFp8 = [&](Type t) { return typeIsExpectedFp8ForChipset(chipset, t); };
1221 auto isBf8 = [&](Type t) { return typeIsExpectedBf8ForChipset(chipset, t); };
1222
1223 uint32_t m = op.getM(), n = op.getN(), k = op.getK();
1224 Type sourceAElem = getElementTypeOrSelf(op.getSourceA().getType());
1225 Type sourceBElem = getElementTypeOrSelf(op.getSourceB().getType());
1226 Type destElem = getElementTypeOrSelf(op.getDestC().getType());
1227
1228 if (m == 16 && n == 16 && k == 32) {
1229 if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32())
1230 return ROCDL::smfmac_f32_16x16x32_f16::getOperationName();
1231 if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32())
1232 return ROCDL::smfmac_f32_16x16x32_bf16::getOperationName();
1233 }
1234
1235 if (m == 16 && n == 16 && k == 64) {
1236 if (isGfx950) {
1237 if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32())
1238 return ROCDL::smfmac_f32_16x16x64_f16::getOperationName();
1239 if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32())
1240 return ROCDL::smfmac_f32_16x16x64_bf16::getOperationName();
1241 }
1242 if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
1243 destElem.isInteger(32))
1244 return ROCDL::smfmac_i32_16x16x64_i8::getOperationName();
1245 if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
1246 return ROCDL::smfmac_f32_16x16x64_fp8_fp8::getOperationName();
1247 if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
1248 return ROCDL::smfmac_f32_16x16x64_fp8_bf8::getOperationName();
1249 if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
1250 return ROCDL::smfmac_f32_16x16x64_bf8_fp8::getOperationName();
1251 if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
1252 return ROCDL::smfmac_f32_16x16x64_bf8_bf8::getOperationName();
1253 }
1254
1255 if (m == 16 && n == 16 && k == 128 && isGfx950) {
1256 if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
1257 destElem.isInteger(32))
1258 return ROCDL::smfmac_i32_16x16x128_i8::getOperationName();
1259 if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
1260 return ROCDL::smfmac_f32_16x16x128_fp8_fp8::getOperationName();
1261 if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
1262 return ROCDL::smfmac_f32_16x16x128_fp8_bf8::getOperationName();
1263 if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
1264 return ROCDL::smfmac_f32_16x16x128_bf8_fp8::getOperationName();
1265 if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
1266 return ROCDL::smfmac_f32_16x16x128_bf8_bf8::getOperationName();
1267 }
1268
1269 if (m == 32 && n == 32 && k == 16) {
1270 if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32())
1271 return ROCDL::smfmac_f32_32x32x16_f16::getOperationName();
1272 if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32())
1273 return ROCDL::smfmac_f32_32x32x16_bf16::getOperationName();
1274 }
1275
1276 if (m == 32 && n == 32 && k == 32) {
1277 if (isGfx950) {
1278 if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32())
1279 return ROCDL::smfmac_f32_32x32x32_f16::getOperationName();
1280 if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32())
1281 return ROCDL::smfmac_f32_32x32x32_bf16::getOperationName();
1282 }
1283 if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
1284 destElem.isInteger(32))
1285 return ROCDL::smfmac_i32_32x32x32_i8::getOperationName();
1286 if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
1287 return ROCDL::smfmac_f32_32x32x32_fp8_fp8::getOperationName();
1288 if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
1289 return ROCDL::smfmac_f32_32x32x32_fp8_bf8::getOperationName();
1290 if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
1291 return ROCDL::smfmac_f32_32x32x32_bf8_fp8::getOperationName();
1292 if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
1293 return ROCDL::smfmac_f32_32x32x32_bf8_bf8::getOperationName();
1294 }
1295
1296 if (m == 32 && n == 32 && k == 64 && isGfx950) {
1297 if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
1298 destElem.isInteger(32))
1299 return ROCDL::smfmac_i32_32x32x64_i8::getOperationName();
1300 if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
1301 return ROCDL::smfmac_f32_32x32x64_fp8_fp8::getOperationName();
1302 if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
1303 return ROCDL::smfmac_f32_32x32x64_fp8_bf8::getOperationName();
1304 if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
1305 return ROCDL::smfmac_f32_32x32x64_bf8_fp8::getOperationName();
1306 if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
1307 return ROCDL::smfmac_f32_32x32x64_bf8_bf8::getOperationName();
1308 }
1309
1310 return std::nullopt;
1311}
1312
1313/// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
1314/// if one exists. This includes checking to ensure the intrinsic is supported
1315/// on the architecture you are compiling for.
1316static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
1317 Chipset chipset) {
1318 auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
1319 auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
1320 auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
1321 Type elemSourceType = sourceVectorType.getElementType();
1322 Type elemBSourceType = sourceBVectorType.getElementType();
1323 Type elemDestType = destVectorType.getElementType();
1324
1325 const uint32_t k = wmma.getK();
1326 const bool isRDNA3 = chipset.majorVersion == 11;
1327 const bool isRDNA4 = chipset.majorVersion == 12 && chipset.minorVersion == 0;
1328
1329 // Handle RDNA3 and RDNA4.
1330 if (isRDNA3 || isRDNA4)
1331 return wmmaOpToIntrinsicRDNA(elemSourceType, elemBSourceType, elemDestType,
1332 k, isRDNA3);
1333
1334 // Handle gfx1250.
1335 if (chipset == kGfx1250)
1336 return wmmaOpToIntrinsicGfx1250(elemSourceType, elemBSourceType,
1337 elemDestType, k);
1338
1339 return std::nullopt;
1340}
1341
1342namespace {
1343struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
1344 MFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1345 : ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {}
1346
1347 Chipset chipset;
1348
1349 LogicalResult
1350 matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
1351 ConversionPatternRewriter &rewriter) const override {
1352 Location loc = op.getLoc();
1353 Type outType = typeConverter->convertType(op.getDestD().getType());
1354 Type intrinsicOutType = outType;
1355 if (auto outVecType = dyn_cast<VectorType>(outType))
1356 if (outVecType.getElementType().isBF16())
1357 intrinsicOutType = outVecType.clone(rewriter.getI16Type());
1358
1359 if (chipset.majorVersion != 9 || chipset < kGfx908)
1360 return op->emitOpError("MFMA only supported on gfx908+");
1361 uint32_t getBlgpField = static_cast<uint32_t>(op.getBlgp());
1362 if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
1363 if (chipset < kGfx942)
1364 return op.emitOpError("negation unsupported on older than gfx942");
1365 getBlgpField |=
1366 op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
1367 }
1368 std::optional<StringRef> maybeIntrinsic = mfmaOpToIntrinsic(op, chipset);
1369 std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1370 maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
1371 if (!maybeIntrinsic.has_value() && !maybeScaledIntrinsic.has_value())
1372 return op.emitOpError("no intrinsic matching MFMA size on given chipset");
1373
1374 bool isScaled =
1375 !maybeIntrinsic.has_value() && maybeScaledIntrinsic.has_value();
1376 if (isScaled &&
1377 (adaptor.getAbid() > 0 || getBlgpField > 0 || op.getCbsz() > 0)) {
1378 return op.emitOpError(
1379 "non-default abid, blgp, and cbsz aren't supported on MFMAs that can "
1380 "be scaled as those fields are used for type information");
1381 }
1382
1383 StringRef intrinsicName =
1384 isScaled ? std::get<0>(*maybeScaledIntrinsic) : *maybeIntrinsic;
1385 // Determine if we can use bf16 in the intrinsic. Newer MFMAs on gfx950+
1386 // allow bf16 as the input. For reference, see IntrinsicsAMDGPU.td.
1387 bool allowBf16 = [&]() {
1388 if (chipset < kGfx950)
1389 return false;
1390 if (isScaled)
1391 return true;
1392 return intrinsicName.contains("16x16x32.bf16") ||
1393 intrinsicName.contains("32x32x16.bf16");
1394 }();
1395 OperationState loweredOp(loc, intrinsicName);
1396 loweredOp.addTypes(intrinsicOutType);
1397 loweredOp.addOperands({packSmallFloatVectorOperand(
1398 rewriter, loc, adaptor.getSourceA(), allowBf16),
1399 packSmallFloatVectorOperand(
1400 rewriter, loc, adaptor.getSourceB(), allowBf16),
1401 adaptor.getDestC()});
1402 if (isScaled) {
1403 Value zero = createI32Constant(rewriter, loc, 0);
1404 auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
1405 loweredOp.addOperands({/*scale A=*/zero, /*scale B=*/zero});
1406 loweredOp.addAttributes({{"cbsz", rewriter.getI32IntegerAttr(aTypeCode)},
1407 {"blgp", rewriter.getI32IntegerAttr(bTypeCode)},
1408 {"opselA", rewriter.getI32IntegerAttr(0)},
1409 {"opselB", rewriter.getI32IntegerAttr(0)}});
1410 } else {
1411 loweredOp.addAttributes(
1412 {{"cbsz", rewriter.getI32IntegerAttr(op.getCbsz())},
1413 {"abid", rewriter.getI32IntegerAttr(op.getAbid())},
1414 {"blgp", rewriter.getI32IntegerAttr(getBlgpField)}});
1415 };
1416 Value lowered = rewriter.create(loweredOp)->getResult(0);
1417 if (outType != intrinsicOutType)
1418 lowered = LLVM::BitcastOp::create(rewriter, loc, outType, lowered);
1419 rewriter.replaceOp(op, lowered);
1420 return success();
1421 }
1422};
1423
1424struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
1425 ScaledMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1426 : ConvertOpToLLVMPattern(converter), chipset(chipset) {}
1427
1428 Chipset chipset;
1429
1430 LogicalResult
1431 matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor,
1432 ConversionPatternRewriter &rewriter) const override {
1433 Location loc = op.getLoc();
1434 Type intrinsicOutType = typeConverter->convertType(op.getDestD().getType());
1435
1436 if (chipset.majorVersion != 9 || chipset < kGfx950)
1437 return op->emitOpError("scaled MFMA only supported on gfx908+");
1438 std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1439 maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
1440 if (!maybeScaledIntrinsic.has_value())
1441 return op.emitOpError(
1442 "no intrinsic matching scaled MFMA size on given chipset");
1443
1444 auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
1445 OperationState loweredOp(loc, intrinsicName);
1446 loweredOp.addTypes(intrinsicOutType);
1447 loweredOp.addOperands(
1448 {packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceA()),
1449 packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceB()),
1450 adaptor.getDestC()});
1451 loweredOp.addOperands(
1452 {/*scales A*/
1453 castScaleOperand(rewriter, loc, adaptor.getScalesA()),
1454 /*scales B*/
1455 castScaleOperand(rewriter, loc, adaptor.getScalesB())});
1456 loweredOp.addAttributes(
1457 {{"cbsz", rewriter.getI32IntegerAttr(aTypeCode)},
1458 {"blgp", rewriter.getI32IntegerAttr(bTypeCode)},
1459 {"opselA", rewriter.getI32IntegerAttr(adaptor.getScalesIdxA())},
1460 {"opselB", rewriter.getI32IntegerAttr(adaptor.getScalesIdxB())}});
1461
1462 Value lowered = rewriter.create(loweredOp)->getResult(0);
1463 rewriter.replaceOp(op, lowered);
1464 return success();
1465 }
1466};
1467
1468struct SparseMFMAOpLowering : public ConvertOpToLLVMPattern<SparseMFMAOp> {
1469 SparseMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1470 : ConvertOpToLLVMPattern<SparseMFMAOp>(converter), chipset(chipset) {}
1471
1472 Chipset chipset;
1473
1474 LogicalResult
1475 matchAndRewrite(SparseMFMAOp op, SparseMFMAOpAdaptor adaptor,
1476 ConversionPatternRewriter &rewriter) const override {
1477 Location loc = op.getLoc();
1478 auto outType =
1479 typeConverter->convertType<VectorType>(op.getDestC().getType());
1480 if (!outType)
1481 return rewriter.notifyMatchFailure(op, "type conversion failed");
1482
1483 // smfmac is supported on gfx942 and gfx950.
1484 if (chipset.majorVersion != 9 || chipset < kGfx942)
1485 return op->emitOpError("sparse MFMA (smfmac) only supported on gfx942+");
1486 bool isGfx950 = chipset >= kGfx950;
1487
1488 Value a = convertSparseMFMAVectorOperand(rewriter, loc,
1489 adaptor.getSourceA(), isGfx950);
1490 Value b = convertSparseMFMAVectorOperand(rewriter, loc,
1491 adaptor.getSourceB(), isGfx950);
1492 Value c = adaptor.getDestC();
1493
1494 std::optional<StringRef> maybeIntrinsic = smfmacOpToIntrinsic(op, chipset);
1495 if (!maybeIntrinsic.has_value())
1496 return op.emitOpError(
1497 "no intrinsic matching sparse MFMA on the given chipset");
1498
1499 // Bitcast sparse indices from vector<4xi8> or vector<2xi16> to i32.
1500 Value sparseIdx = LLVM::BitcastOp::create(
1501 rewriter, loc, rewriter.getI32Type(), adaptor.getSparseIdx());
1502
1503 OperationState loweredOp(loc, maybeIntrinsic.value());
1504 loweredOp.addTypes(outType);
1505 loweredOp.addOperands({a, b, c, sparseIdx});
1506 loweredOp.addAttributes(
1507 {{"cbsz", rewriter.getI32IntegerAttr(op.getCbsz())},
1508 {"abid", rewriter.getI32IntegerAttr(op.getAbid())}});
1509 Value lowered = rewriter.create(loweredOp)->getResult(0);
1510 rewriter.replaceOp(op, lowered);
1511 return success();
1512 }
1513};
1514
1515struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
1516 WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1517 : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}
1518
1519 Chipset chipset;
1520
1521 LogicalResult
1522 matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
1523 ConversionPatternRewriter &rewriter) const override {
1524 Location loc = op.getLoc();
1525 auto outType =
1526 typeConverter->convertType<VectorType>(op.getDestD().getType());
1527 if (!outType)
1528 return rewriter.notifyMatchFailure(op, "type conversion failed");
1529
1530 if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
1531 return op->emitOpError("WMMA only supported on gfx11 and gfx12");
1532
1533 bool isGFX1250 = chipset >= kGfx1250;
1534
1535 // The WMMA operations represent vectors of bf16s as vectors of i16s
1536 // (except on gfx1250), so we need to bitcast bfloats to i16 and then
1537 // bitcast them back.
1538 auto aType = cast<VectorType>(adaptor.getSourceA().getType());
1539 auto bType = cast<VectorType>(adaptor.getSourceB().getType());
1540 auto destCType = cast<VectorType>(adaptor.getDestC().getType());
1541 bool castAToI16 = aType.getElementType().isBF16() && !isGFX1250;
1542 bool castBToI16 = bType.getElementType().isBF16() && !isGFX1250;
1543 bool castDestCToI16 = destCType.getElementType().isBF16() && !isGFX1250;
1544 bool castOutToI16 = outType.getElementType().isBF16() && !isGFX1250;
1545 VectorType rawOutType = outType;
1546 if (castOutToI16)
1547 rawOutType = outType.clone(rewriter.getI16Type());
1548 Value a = adaptor.getSourceA();
1549 if (castAToI16)
1550 a = LLVM::BitcastOp::create(rewriter, loc,
1551 aType.clone(rewriter.getI16Type()), a);
1552 Value b = adaptor.getSourceB();
1553 if (castBToI16)
1554 b = LLVM::BitcastOp::create(rewriter, loc,
1555 bType.clone(rewriter.getI16Type()), b);
1556 Value destC = adaptor.getDestC();
1557 if (castDestCToI16)
1558 destC = LLVM::BitcastOp::create(
1559 rewriter, loc, destCType.clone(rewriter.getI16Type()), destC);
1560
1561 std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);
1562
1563 if (!maybeIntrinsic.has_value())
1564 return op.emitOpError("no intrinsic matching WMMA on the given chipset");
1565
1566 if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0)
1567 return op.emitOpError("subwordOffset not supported on gfx12+");
1568
1569 SmallVector<Value, 4> operands;
1570 SmallVector<NamedAttribute, 4> attrs;
1571 wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(), a,
1572 op.getSourceA(), operands, attrs, "signA");
1573 wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(), b,
1574 op.getSourceB(), operands, attrs, "signB");
1575 wmmaPushOutputOperand(rewriter, loc, typeConverter, destC,
1576 op.getSubwordOffset(), op.getClamp(), operands,
1577 attrs);
1578
1579 OperationState loweredOp(loc, *maybeIntrinsic);
1580 loweredOp.addTypes(rawOutType);
1581 loweredOp.addOperands(operands);
1582 loweredOp.addAttributes(attrs);
1583 Operation *lowered = rewriter.create(loweredOp);
1584
1585 Operation *maybeCastBack = lowered;
1586 if (rawOutType != outType)
1587 maybeCastBack = LLVM::BitcastOp::create(rewriter, loc, outType,
1588 lowered->getResult(0));
1589 rewriter.replaceOp(op, maybeCastBack->getResults());
1590
1591 return success();
1592 }
1593};
1594
1595struct ScaledWMMAOpLowering : public ConvertOpToLLVMPattern<ScaledWMMAOp> {
1596 ScaledWMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1597 : ConvertOpToLLVMPattern<ScaledWMMAOp>(converter), chipset(chipset) {}
1598
1599 Chipset chipset;
1600
1601 LogicalResult
1602 matchAndRewrite(ScaledWMMAOp op, ScaledWMMAOpAdaptor adaptor,
1603 ConversionPatternRewriter &rewriter) const override {
1604 Location loc = op.getLoc();
1605 auto outType =
1606 typeConverter->convertType<VectorType>(op.getDestD().getType());
1607 if (!outType)
1608 return rewriter.notifyMatchFailure(op, "type conversion failed");
1609
1610 if (chipset < kGfx1250)
1611 return op->emitOpError("WMMA scale only supported on gfx1250+");
1612
1613 int64_t m = op.getM();
1614 int64_t n = op.getN();
1615 int64_t k = op.getK();
1616
1617 Type aElemType = getElementTypeOrSelf(op.getSourceA().getType());
1618 Type bElemType = getElementTypeOrSelf(op.getSourceB().getType());
1619
1620 std::optional<uint32_t> aFmtCode = smallFloatTypeToFormatCode(aElemType);
1621 std::optional<uint32_t> bFmtCode = smallFloatTypeToFormatCode(bElemType);
1622
1623 if (!aFmtCode || !bFmtCode)
1624 return op.emitOpError("unsupported element types for scaled_wmma");
1625
1626 // Get scale vector types and determine variant (scale vs scale16).
1627 auto scaleAVecType = cast<VectorType>(op.getScaleA().getType());
1628 auto scaleBVecType = cast<VectorType>(op.getScaleB().getType());
1629
1630 if (scaleAVecType.getNumElements() != scaleBVecType.getNumElements())
1631 return op.emitOpError("scaleA and scaleB must have equal vector length");
1632
1633 // Extract scale format from element types.
1634 Type scaleAElemType = scaleAVecType.getElementType();
1635 Type scaleBElemType = scaleBVecType.getElementType();
1636
1637 std::optional<uint32_t> scaleAFmt = getWmmaScaleFormat(scaleAElemType);
1638 std::optional<uint32_t> scaleBFmt = getWmmaScaleFormat(scaleBElemType);
1639
1640 if (!scaleAFmt || !scaleBFmt)
1641 return op.emitOpError("unsupported scale element types");
1642
1643 // Determine which intrinsic to use based on dimensions.
1644 bool isScale16 = (scaleAVecType.getNumElements() == 8);
1645 std::optional<StringRef> intrinsicName =
1646 getScaledWmmaIntrinsicName(m, n, k, isScale16);
1647 if (!intrinsicName)
1648 return op.emitOpError("unsupported scaled_wmma dimensions: ")
1649 << m << "x" << n << "x" << k;
1650
1651 SmallVector<NamedAttribute, 8> attrs;
1652
1653 // The f4 variant does not have fmtA and fmtB attributes.
1654 bool is32x16 = (m == 32 && n == 16 && k == 128);
1655 if (!is32x16) {
1656 attrs.emplace_back("fmtA", rewriter.getI32IntegerAttr(*aFmtCode));
1657 attrs.emplace_back("fmtB", rewriter.getI32IntegerAttr(*bFmtCode));
1658 }
1659
1660 // modC uses default value of 0.
1661 attrs.emplace_back("modC", rewriter.getI16IntegerAttr(0));
1662
1663 // Scale attributes. Convert user-facing firstScaleLane (0 or 16) to the
1664 // half of the wave that is being selected (0 or 1).
1665 attrs.emplace_back(
1666 "scaleAType", rewriter.getI32IntegerAttr(op.getAFirstScaleLane() / 16));
1667 attrs.emplace_back("fmtScaleA", rewriter.getI32IntegerAttr(*scaleAFmt));
1668 attrs.emplace_back(
1669 "scaleBType", rewriter.getI32IntegerAttr(op.getBFirstScaleLane() / 16));
1670 attrs.emplace_back("fmtScaleB", rewriter.getI32IntegerAttr(*scaleBFmt));
1671
1672 // Reuse flags use default value of false.
1673 attrs.emplace_back("reuseA", rewriter.getBoolAttr(false));
1674 attrs.emplace_back("reuseB", rewriter.getBoolAttr(false));
1675
1676 // Convert typed float vectors to packed format.
1677 Value sourceA =
1678 packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceA());
1679 Value sourceB =
1680 packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceB());
1681
1682 // Pack scale vectors into i32/i64.
1683 Value packedScaleA = castScaleOperand(rewriter, loc, adaptor.getScaleA());
1684 Value packedScaleB = castScaleOperand(rewriter, loc, adaptor.getScaleB());
1685
1686 // Create the intrinsic call.
1687 OperationState loweredOp(loc, *intrinsicName);
1688 loweredOp.addTypes(outType);
1689 loweredOp.addOperands(
1690 {sourceA, sourceB, adaptor.getDestC(), packedScaleA, packedScaleB});
1691 loweredOp.addAttributes(attrs);
1692
1693 Operation *lowered = rewriter.create(loweredOp);
1694 rewriter.replaceOp(op, lowered->getResults());
1695
1696 return success();
1697 }
1698};
1699
1700struct TransposeLoadOpLowering
1701 : public ConvertOpToLLVMPattern<TransposeLoadOp> {
1702 TransposeLoadOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1703 : ConvertOpToLLVMPattern<TransposeLoadOp>(converter), chipset(chipset) {}
1704
1705 Chipset chipset;
1706
1707 LogicalResult
1708 matchAndRewrite(TransposeLoadOp op, TransposeLoadOpAdaptor adaptor,
1709 ConversionPatternRewriter &rewriter) const override {
1710 if (chipset != kGfx950)
1711 return op.emitOpError("Non-gfx950 chipset not supported");
1712
1713 Location loc = op.getLoc();
1714 auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
1715
1716 // Elements in sub-byte memrefs are stored non-contiguously; reject
1717 // sub-byte source memrefs and use emulated memrefs instead.
1718 size_t srcElementSize =
1719 srcMemRefType.getElementType().getIntOrFloatBitWidth();
1720 if (srcElementSize < 8)
1721 return op.emitOpError("Expect source memref to have at least 8 bits "
1722 "element size, got ")
1723 << srcElementSize;
1724
1725 auto resultType = cast<VectorType>(op.getResult().getType());
1726 Value srcPtr =
1727 getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
1728 (adaptor.getSrcIndices()));
1729
1730 size_t numElements = resultType.getNumElements();
1731 size_t elementTypeSize =
1732 resultType.getElementType().getIntOrFloatBitWidth();
1733
1734 // ROCDL transpose load intrinsics return vectors of 32-bit integers when
1735 // the element size is smaller than 16 bits.
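// Illustrative example (not from the original source): a transpose load of 16
// 4-bit elements produces a vector<2xi32> ((16 * 4) / 32) from
// ds_read_tr4_b64, which is then bitcast to the converted result type.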
1736 Type rocdlResultType = VectorType::get((numElements * elementTypeSize) / 32,
1737 rewriter.getIntegerType(32));
1738 Type llvmResultType = typeConverter->convertType(resultType);
1739
1740 switch (elementTypeSize) {
1741 case 4: {
1742 assert(numElements == 16);
1743 auto rocdlOp = ROCDL::ds_read_tr4_b64::create(rewriter, loc,
1744 rocdlResultType, srcPtr);
1745 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
1746 break;
1747 }
1748 case 6: {
1749 assert(numElements == 16);
1750 auto rocdlOp = ROCDL::ds_read_tr6_b96::create(rewriter, loc,
1751 rocdlResultType, srcPtr);
1752 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
1753 break;
1754 }
1755 case 8: {
1756 assert(numElements == 8);
1757 auto rocdlOp = ROCDL::ds_read_tr8_b64::create(rewriter, loc,
1758 rocdlResultType, srcPtr);
1759 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
1760 break;
1761 }
1762 case 16: {
1763 assert(numElements == 4);
1764 rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr16_b64>(op, llvmResultType,
1765 srcPtr);
1766 break;
1767 }
1768 default:
1769 return op.emitOpError("Unsupported element size for transpose load");
1770 }
1771 return success();
1772 }
1773};
1774
1775struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
1776 GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1777 : ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}
1778
1779 Chipset chipset;
1780
1781 LogicalResult
1782 matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor,
1783 ConversionPatternRewriter &rewriter) const override {
1784 if (chipset.majorVersion < 9 || chipset.majorVersion > 10)
1785 return op.emitOpError("pre-gfx9 and post-gfx10 not supported");
1786
1787 Location loc = op.getLoc();
1788
1789 auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
1790 auto dstMemRefType = cast<MemRefType>(op.getDst().getType());
1791
1792 // TODO: instead of only transferring one element per thread, we could
1793 // augment it to transfer multiple elements per thread by issuing multiple
1794 // `global_load_lds` instructions.
1795 Type transferType = op.getTransferType();
1796 int loadWidth = [&]() -> int {
1797 if (auto transferVectorType = dyn_cast<VectorType>(transferType)) {
1798 return (transferVectorType.getNumElements() *
1799 transferVectorType.getElementTypeBitWidth()) /
1800 8;
1801 }
1802 return transferType.getIntOrFloatBitWidth() / 8;
1803 }();
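// Illustrative example (not from the original source): a transfer type of
// vector<2xf16> gives loadWidth = (2 * 16) / 8 = 4 bytes; a scalar f32 also
// gives 4 bytes.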
1804
1805 // Currently only 1, 2, 4, 12 and 16 byte loads are supported.
1806 if (!llvm::is_contained({1, 2, 4, 12, 16}, loadWidth))
1807 return op.emitOpError("chipset unsupported element size");
1808
1809 if (chipset != kGfx950 && llvm::is_contained({12, 16}, loadWidth))
1810 return op.emitOpError("Gather to LDS instructions with 12-byte and "
1811 "16-byte load widths are only supported on gfx950");
1812
1813 Value srcPtr =
1814 getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
1815 (adaptor.getSrcIndices()));
1816 Value dstPtr =
1817 getStridedElementPtr(rewriter, loc, dstMemRefType, adaptor.getDst(),
1818 (adaptor.getDstIndices()));
1819
1820 rewriter.replaceOpWithNewOp<ROCDL::LoadToLDSOp>(
1821 op, srcPtr, dstPtr, rewriter.getI32IntegerAttr(loadWidth),
1822 /*offset=*/rewriter.getI32IntegerAttr(0),
1823 /*aux=*/rewriter.getI32IntegerAttr(0), ArrayAttr{}, ArrayAttr{},
1824 ArrayAttr{});
1825
1826 return success();
1827 }
1828};
1829
1830namespace {
1831struct ExtPackedFp8OpLowering final
1832 : public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
1833 ExtPackedFp8OpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1834 : ConvertOpToLLVMPattern<amdgpu::ExtPackedFp8Op>(converter),
1835 chipset(chipset) {}
1836 Chipset chipset;
1837
1838 LogicalResult
1839 matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
1840 ConversionPatternRewriter &rewriter) const override;
1841};
1842
1843struct ScaledExtPackedMatrixOpLowering final
1844 : public ConvertOpToLLVMPattern<ScaledExtPackedMatrixOp> {
1845 ScaledExtPackedMatrixOpLowering(const LLVMTypeConverter &converter,
1846 Chipset chipset)
1847 : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedMatrixOp>(converter),
1848 chipset(chipset) {}
1849 Chipset chipset;
1850
1851 LogicalResult
1852 matchAndRewrite(ScaledExtPackedMatrixOp op,
1853 ScaledExtPackedMatrixOpAdaptor adaptor,
1854 ConversionPatternRewriter &rewriter) const override;
1855};
1856
1857struct PackedTrunc2xFp8OpLowering final
1858 : public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> {
1859 PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter,
1860 Chipset chipset)
1861 : ConvertOpToLLVMPattern<amdgpu::PackedTrunc2xFp8Op>(converter),
1862 chipset(chipset) {}
1863 Chipset chipset;
1864
1865 LogicalResult
1866 matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
1867 ConversionPatternRewriter &rewriter) const override;
1868};
1869
1870struct PackedStochRoundFp8OpLowering final
1871 : public ConvertOpToLLVMPattern<PackedStochRoundFp8Op> {
1872 PackedStochRoundFp8OpLowering(const LLVMTypeConverter &converter,
1873 Chipset chipset)
1874 : ConvertOpToLLVMPattern<amdgpu::PackedStochRoundFp8Op>(converter),
1875 chipset(chipset) {}
1876 Chipset chipset;
1877
1878 LogicalResult
1879 matchAndRewrite(PackedStochRoundFp8Op op,
1880 PackedStochRoundFp8OpAdaptor adaptor,
1881 ConversionPatternRewriter &rewriter) const override;
1882};
1883
1884struct ScaledExtPackedOpLowering final
1885 : public ConvertOpToLLVMPattern<ScaledExtPackedOp> {
1886 ScaledExtPackedOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1887 : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedOp>(converter),
1888 chipset(chipset) {}
1889 Chipset chipset;
1890
1891 LogicalResult
1892 matchAndRewrite(ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
1893 ConversionPatternRewriter &rewriter) const override;
1894};
1895
1896struct PackedScaledTruncOpLowering final
1897 : public ConvertOpToLLVMPattern<PackedScaledTruncOp> {
1898 PackedScaledTruncOpLowering(const LLVMTypeConverter &converter,
1899 Chipset chipset)
1900 : ConvertOpToLLVMPattern<amdgpu::PackedScaledTruncOp>(converter),
1901 chipset(chipset) {}
1902 Chipset chipset;
1903
1904 LogicalResult
1905 matchAndRewrite(PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
1906 ConversionPatternRewriter &rewriter) const override;
1907};
1908
1909} // end namespace
1910
1911LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
1912 ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
1913 ConversionPatternRewriter &rewriter) const {
1914 Location loc = op.getLoc();
1915 if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
1916 return rewriter.notifyMatchFailure(
1917 loc, "Fp8 conversion instructions are not available on target "
1918 "architecture and their emulation is not implemented");
1919 Type v4i8 =
1920 getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type()));
1921 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
1922 Type f32 = getTypeConverter()->convertType(op.getResult().getType());
1923
1924 Value source = adaptor.getSource();
1925 auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
1926 auto resultVecType = dyn_cast<VectorType>(op.getResult().getType());
1927 Type sourceElemType = getElementTypeOrSelf(op.getSource());
1928 // Extend to a v4i8
1929 if (!sourceVecType || sourceVecType.getNumElements() < 4) {
1930 Value longVec = LLVM::UndefOp::create(rewriter, loc, v4i8);
1931 if (!sourceVecType) {
1932 longVec = LLVM::InsertElementOp::create(
1933 rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0));
1934 } else {
1935 for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
1936 Value idx = createI32Constant(rewriter, loc, i);
1937 Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
1938 longVec =
1939 LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
1940 }
1941 }
1942 source = longVec;
1943 }
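// Illustrative note (not from the original source): a scalar fp8 source ends
// up in lane 0 of an otherwise-undef vector<4xi8>, and the 4-byte vector is
// bitcast to i32 below because the ROCDL conversion intrinsics take packed
// i32 inputs.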
1944 Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
1945 if (resultVecType) {
1946 if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
1947 rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source,
1948 op.getIndex());
1949 } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
1950 rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Fp8Op>(op, f32, i32Source,
1951 op.getIndex());
1952 }
1953 } else {
1954 if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
1955 rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
1956 op.getIndex());
1957 } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
1958 rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
1959 op.getIndex());
1960 }
1961 }
1962 return success();
1963}
1964
1965int32_t getScaleSel(int32_t blockSize, unsigned bitWidth, int32_t scaleWaveHalf,
1966 int32_t firstScaleByte) {
1967 // When lowering amdgpu.scaled_ext_packed_matrix to rocdl.cvt.scale.pk*.f*.f*
1968 // operations, the attributes blockSize, sourceType, scaleWaveHalf, and
1969 // firstScaleByte are merged into a single attribute scaleSel. This is how
1970 // those values are merged together. (Note: scaleWaveHalf isn't a high-level
1971 // attribute but is derived from firstScaleLane).
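// Illustrative example (not from the original source): an fp8 source with
// blockSize 16, scaleWaveHalf 1, and firstScaleByte 2 yields bit0 = 1,
// bits2and1 = 0b100, and bit3 = 0b1000, i.e. scaleSel = 0b1101.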
1972 assert(llvm::is_contained({16, 32}, blockSize));
1973 assert(llvm::is_contained({4u, 6u, 8u}, bitWidth));
1974
1975 const bool isFp8 = bitWidth == 8;
1976 const bool isBlock16 = blockSize == 16;
1977
1978 if (!isFp8) {
1979 int32_t bit0 = isBlock16;
1980 assert(llvm::is_contained({0, 1, 2}, firstScaleByte));
1981 int32_t bit1 = (firstScaleByte == 2) << 1;
1982 assert(llvm::is_contained({0, 1}, scaleWaveHalf));
1983 int32_t bit2 = scaleWaveHalf << 2;
1984 return bit2 | bit1 | bit0;
1985 }
1986
1987 int32_t bit0 = isBlock16;
1988 // firstScaleByte is guaranteed to fit in two bits.
1989 assert(llvm::is_contained({0, 1, 2, 3}, firstScaleByte));
1990 int32_t bits2and1 = firstScaleByte << 1;
1991 assert(llvm::is_contained({0, 1}, scaleWaveHalf));
1992 int32_t bit3 = scaleWaveHalf << 3;
1993 int32_t bits = bit3 | bits2and1 | bit0;
1994 // These are invalid cases.
1995 assert(!llvm::is_contained(
1996 {0b0011, 0b0101, 0b0111, 0b1000, 0b1001, 0b1011, 0b1111}, bits));
1997 return bits;
1998}
1999
2000static std::optional<StringRef>
2001scaledExtPacked816ToIntrinsic(Type srcElemType, Type destElemType) {
2002 using fp4 = Float4E2M1FNType;
2003 using fp8 = Float8E4M3FNType;
2004 using bf8 = Float8E5M2Type;
2005 using fp6 = Float6E2M3FNType;
2006 using bf6 = Float6E3M2FNType;
2007 if (isa<fp4>(srcElemType)) {
2008 if (destElemType.isF16())
2009 return ROCDL::CvtPkScalePk8F16Fp4Op::getOperationName();
2010 if (destElemType.isBF16())
2011 return ROCDL::CvtPkScalePk8Bf16Fp4Op::getOperationName();
2012 if (destElemType.isF32())
2013 return ROCDL::CvtPkScalePk8F32Fp4Op::getOperationName();
2014 return std::nullopt;
2015 }
2016 if (isa<fp8>(srcElemType)) {
2017 if (destElemType.isF16())
2018 return ROCDL::CvtPkScalePk8F16Fp8Op::getOperationName();
2019 if (destElemType.isBF16())
2020 return ROCDL::CvtPkScalePk8Bf16Fp8Op::getOperationName();
2021 if (destElemType.isF32())
2022 return ROCDL::CvtPkScalePk8F32Fp8Op::getOperationName();
2023 return std::nullopt;
2024 }
2025 if (isa<bf8>(srcElemType)) {
2026 if (destElemType.isF16())
2027 return ROCDL::CvtPkScalePk8F16Bf8Op::getOperationName();
2028 if (destElemType.isBF16())
2029 return ROCDL::CvtPkScalePk8Bf16Bf8Op::getOperationName();
2030 if (destElemType.isF32())
2031 return ROCDL::CvtPkScalePk8F32Bf8Op::getOperationName();
2032 return std::nullopt;
2033 }
2034 if (isa<fp6>(srcElemType)) {
2035 if (destElemType.isF16())
2036 return ROCDL::CvtPkScalePk16F16Fp6Op::getOperationName();
2037 if (destElemType.isBF16())
2038 return ROCDL::CvtPkScalePk16Bf16Fp6Op::getOperationName();
2039 if (destElemType.isF32())
2040 return ROCDL::CvtPkScalePk16F32Fp6Op::getOperationName();
2041 return std::nullopt;
2042 }
2043 if (isa<bf6>(srcElemType)) {
2044 if (destElemType.isF16())
2045 return ROCDL::CvtPkScalePk16F16Bf6Op::getOperationName();
2046 if (destElemType.isBF16())
2047 return ROCDL::CvtPkScalePk16Bf16Bf6Op::getOperationName();
2048 if (destElemType.isF32())
2049 return ROCDL::CvtPkScalePk16F32Bf6Op::getOperationName();
2050 return std::nullopt;
2051 }
2052 llvm_unreachable("invalid combination of element types for packed conversion "
2053 "instructions");
2054}
2055
2056LogicalResult ScaledExtPackedMatrixOpLowering::matchAndRewrite(
2057 ScaledExtPackedMatrixOp op, ScaledExtPackedMatrixOpAdaptor adaptor,
2058 ConversionPatternRewriter &rewriter) const {
2059 using fp4 = Float4E2M1FNType;
2060 using fp8 = Float8E4M3FNType;
2061 using bf8 = Float8E5M2Type;
2062 using fp6 = Float6E2M3FNType;
2063 using bf6 = Float6E3M2FNType;
2064 Location loc = op.getLoc();
2065 if (chipset != kGfx1250) {
2066 return rewriter.notifyMatchFailure(
2067 loc,
2068 "Scaled fp packed conversion instructions are not available on target "
2069 "architecture and their emulation is not implemented");
2070 }
2071 // Convert user-facing firstScaleLane (0 or 16) to the half of the wave that
2072 // is being selected.
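// Illustrative example (not from the original source): firstScaleLane = 16
// gives scaleWaveHalf = 1 (the upper half of the wave); firstScaleLane = 0
// gives 0.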
2073 int32_t scaleWaveHalf = op.getFirstScaleLane() / 16;
2074 int32_t firstScaleByte = op.getFirstScaleByte();
2075 int32_t blockSize = op.getBlockSize();
2076 auto sourceType = cast<VectorType>(op.getSource().getType());
2077 auto srcElemType = cast<FloatType>(sourceType.getElementType());
2078 unsigned bitWidth = srcElemType.getWidth();
2079
2080 auto targetType = cast<VectorType>(op.getResult().getType());
2081 auto destElemType = cast<FloatType>(targetType.getElementType());
2082
2083 IntegerType i32 = rewriter.getI32Type();
2084 Value source = adaptor.getSource();
2085 Type llvmResultType = typeConverter->convertType(op.getResult().getType());
2086 Type packedType = nullptr;
2087 if (isa<fp4>(srcElemType)) {
2088 packedType = i32;
2089 packedType = getTypeConverter()->convertType(packedType);
2090 } else if (isa<fp8, bf8>(srcElemType)) {
2091 packedType = VectorType::get(2, i32);
2092 packedType = getTypeConverter()->convertType(packedType);
2093 } else if (isa<fp6, bf6>(srcElemType)) {
2094 packedType = VectorType::get(3, i32);
2095 packedType = getTypeConverter()->convertType(packedType);
2096 } else {
2097 llvm_unreachable("invalid element type for packed scaled ext");
2098 }
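// Illustrative note (not from the original source): these widths follow from
// the Pk8/Pk16 intrinsics above -- 8 fp4 values fit in one i32, 8 fp8 values
// in a vector<2xi32>, and 16 fp6 values in a vector<3xi32> (96 bits).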
2099
2100 if (!packedType || !llvmResultType) {
2101 return rewriter.notifyMatchFailure(op, "type conversion failed");
2102 }
2103
2104 std::optional<StringRef> maybeIntrinsic =
2105 scaledExtPacked816ToIntrinsic(srcElemType, destElemType);
2106 if (!maybeIntrinsic.has_value())
2107 return op.emitOpError(
2108 "no intrinsic matching packed scaled conversion on the given chipset");
2109
2110 int32_t scaleSel =
2111 getScaleSel(blockSize, bitWidth, scaleWaveHalf, firstScaleByte);
2112 Value castedScale =
2113 LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getScale());
2114 Value castedSource =
2115 LLVM::BitcastOp::create(rewriter, loc, packedType, source);
2116
2117 OperationState loweredOp(loc, *maybeIntrinsic);
2118 loweredOp.addTypes({llvmResultType});
2119 loweredOp.addOperands({castedSource, castedScale});
2120
2121 SmallVector<NamedAttribute, 1> attrs;
2122 attrs.push_back(
2123 NamedAttribute("scaleSel", rewriter.getI32IntegerAttr(scaleSel)));
2124
2125 loweredOp.addAttributes(attrs);
2126 Operation *lowered = rewriter.create(loweredOp);
2127 rewriter.replaceOp(op, lowered);
2128
2129 return success();
2130}
2131
2132LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
2133 ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
2134 ConversionPatternRewriter &rewriter) const {
2135 Location loc = op.getLoc();
2136 if (chipset != kGfx950)
2137 return rewriter.notifyMatchFailure(
2138 loc, "Scaled fp conversion instructions are not available on target "
2139 "architecture and their emulation is not implemented");
2140 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2141
2142 Value source = adaptor.getSource();
2143 Value scale = adaptor.getScale();
2144
2145 VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
2146 Type sourceElemType = sourceVecType.getElementType();
2147 VectorType destVecType = cast<VectorType>(op.getResult().getType());
2148 Type destElemType = destVecType.getElementType();
2149
2150 VectorType packedVecType;
2151 if (isa<Float8E5M2Type, Float8E4M3FNType>(sourceElemType)) {
2152 VectorType v4i8 = VectorType::get(4, rewriter.getI8Type());
2153 packedVecType = cast<VectorType>(getTypeConverter()->convertType(v4i8));
2154 } else if (isa<Float4E2M1FNType>(sourceElemType)) {
2155 VectorType v8i4 = VectorType::get(8, rewriter.getI4Type());
2156 packedVecType = cast<VectorType>(getTypeConverter()->convertType(v8i4));
2157 } else {
2158 llvm_unreachable("invalid element type for scaled ext");
2159 }
2160
2161 // Extend to a packedVectorType
2162 if (sourceVecType.getNumElements() < packedVecType.getNumElements()) {
2163 Value longVec = LLVM::ZeroOp::create(rewriter, loc, packedVecType);
2164 if (!sourceVecType) {
2165 longVec = LLVM::InsertElementOp::create(
2166 rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0));
2167 } else {
2168 for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
2169 Value idx = createI32Constant(rewriter, loc, i);
2170 Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
2171 longVec =
2172 LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
2173 }
2174 }
2175 source = longVec;
2176 }
2177 Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
2178
2179 if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF32())
2180 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Bf8Op>(
2181 op, destVecType, i32Source, scale, op.getIndex());
2182 else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF16())
2183 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Bf8Op>(
2184 op, destVecType, i32Source, scale, op.getIndex());
2185 else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isBF16())
2186 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Bf8Op>(
2187 op, destVecType, i32Source, scale, op.getIndex());
2188 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF32())
2189 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp8Op>(
2190 op, destVecType, i32Source, scale, op.getIndex());
2191 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF16())
2192 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp8Op>(
2193 op, destVecType, i32Source, scale, op.getIndex());
2194 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isBF16())
2195 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp8Op>(
2196 op, destVecType, i32Source, scale, op.getIndex());
2197 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF32())
2198 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp4Op>(
2199 op, destVecType, i32Source, scale, op.getIndex());
2200 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF16())
2201 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp4Op>(
2202 op, destVecType, i32Source, scale, op.getIndex());
2203 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isBF16())
2204 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp4Op>(
2205 op, destVecType, i32Source, scale, op.getIndex());
2206 else
2207 return failure();
2208
2209 return success();
2210}
2211
2212LogicalResult PackedScaledTruncOpLowering::matchAndRewrite(
2213 PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
2214 ConversionPatternRewriter &rewriter) const {
2215 Location loc = op.getLoc();
2216 if (chipset != kGfx950)
2217 return rewriter.notifyMatchFailure(
2218 loc, "Scaled fp conversion instructions are not available on target "
2219 "architecture and their emulation is not implemented");
2220 Type v2i16 = getTypeConverter()->convertType(
2221 VectorType::get(2, rewriter.getI16Type()));
2222 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2223
2224 Type resultType = op.getResult().getType();
2225 Type resultElemType = getElementTypeOrSelf(resultType);
2226 VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
2227 Type sourceElemType = sourceVecType.getElementType();
2228
2229 Type intResultType = isa<Float4E2M1FNType>(resultElemType) ? i32 : v2i16;
2230
2231 Value source = adaptor.getSource();
2232 Value scale = adaptor.getScale();
2233 Value existing = adaptor.getExisting();
2234 if (existing)
2235 existing = LLVM::BitcastOp::create(rewriter, loc, intResultType, existing);
2236 else
2237 existing = LLVM::ZeroOp::create(rewriter, loc, intResultType);
2238
2239 if (sourceVecType.getNumElements() < 2) {
2240 Value c0 = createI32Constant(rewriter, loc, 0);
2241 Value elem0 = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
2242 VectorType v2 = VectorType::get(2, sourceElemType);
2243 source = LLVM::ZeroOp::create(rewriter, loc, v2);
2244 source = LLVM::InsertElementOp::create(rewriter, loc, source, elem0, c0);
2245 }
2246
2247 Value sourceA, sourceB;
2248 if (sourceElemType.isF32()) {
2249 Value c0 = createI32Constant(rewriter, loc, 0);
2250 Value c1 = createI32Constant(rewriter, loc, 1);
2251 sourceA = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
2252 sourceB = LLVM::ExtractElementOp::create(rewriter, loc, source, c1);
2253 }
2254
2255 Value result;
2256 if (sourceElemType.isF32() && isa<Float8E5M2Type>(resultElemType))
2257 result = ROCDL::CvtScaleF32PkBf8F32Op::create(rewriter, loc, intResultType,
2258 existing, sourceA, sourceB,
2259 scale, op.getIndex());
2260 else if (sourceElemType.isF16() && isa<Float8E5M2Type>(resultElemType))
2261 result = ROCDL::CvtScaleF32PkBf8F16Op::create(
2262 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2263 else if (sourceElemType.isBF16() && isa<Float8E5M2Type>(resultElemType))
2264 result = ROCDL::CvtScaleF32PkBf8Bf16Op::create(
2265 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2266 else if (sourceElemType.isF32() && isa<Float8E4M3FNType>(resultElemType))
2267 result = ROCDL::CvtScaleF32PkFp8F32Op::create(rewriter, loc, intResultType,
2268 existing, sourceA, sourceB,
2269 scale, op.getIndex());
2270 else if (sourceElemType.isF16() && isa<Float8E4M3FNType>(resultElemType))
2271 result = ROCDL::CvtScaleF32PkFp8F16Op::create(
2272 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2273 else if (sourceElemType.isBF16() && isa<Float8E4M3FNType>(resultElemType))
2274 result = ROCDL::CvtScaleF32PkFp8Bf16Op::create(
2275 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2276 else if (sourceElemType.isF32() && isa<Float4E2M1FNType>(resultElemType))
2277 result = ROCDL::CvtScaleF32PkFp4F32Op::create(rewriter, loc, intResultType,
2278 existing, sourceA, sourceB,
2279 scale, op.getIndex());
2280 else if (sourceElemType.isF16() && isa<Float4E2M1FNType>(resultElemType))
2281 result = ROCDL::CvtScaleF32PkFp4F16Op::create(
2282 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2283 else if (sourceElemType.isBF16() && isa<Float4E2M1FNType>(resultElemType))
2284 result = ROCDL::CvtScaleF32PkFp4Bf16Op::create(
2285 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2286 else
2287 return failure();
2288
2289 result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
2290 op, getTypeConverter()->convertType(resultType), result);
2291 return success();
2292}
2293
2294LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
2295 PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
2296 ConversionPatternRewriter &rewriter) const {
2297 Location loc = op.getLoc();
2298 if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
2299 return rewriter.notifyMatchFailure(
2300 loc, "Fp8 conversion instructions are not available on target "
2301 "architecture and their emulation is not implemented");
2302 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2303
2304 Type resultType = op.getResult().getType();
2305 Type resultElemType = getElementTypeOrSelf(resultType);
2306
2307 Value sourceA = adaptor.getSourceA();
2308 Value sourceB = adaptor.getSourceB();
2309 if (!sourceB)
2310 sourceB = LLVM::UndefOp::create(rewriter, loc, sourceA.getType());
2311 Value existing = adaptor.getExisting();
2312 if (existing)
2313 existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
2314 else
2315 existing = LLVM::UndefOp::create(rewriter, loc, i32);
2316
2317 Value result;
2318 if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
2319 result = ROCDL::CvtPkBf8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
2320 existing, op.getWordIndex());
2321 else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
2322 result = ROCDL::CvtPkFp8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
2323 existing, op.getWordIndex());
2324
2325 result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
2326 op, getTypeConverter()->convertType(resultType), result);
2327 return success();
2328}
2329
2330LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
2331 PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
2332 ConversionPatternRewriter &rewriter) const {
2333 Location loc = op.getLoc();
2334 if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
2335 return rewriter.notifyMatchFailure(
2336 loc, "Fp8 conversion instructions are not available on target "
2337 "architecture and their emulation is not implemented");
2338 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2339
2340 Type resultType = op.getResult().getType();
2341 Type resultElemType = getElementTypeOrSelf(resultType);
2342
2343 Value source = adaptor.getSource();
2344 Value stoch = adaptor.getStochiasticParam();
2345 Value existing = adaptor.getExisting();
2346 if (existing)
2347 existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
2348 else
2349 existing = LLVM::UndefOp::create(rewriter, loc, i32);
2350
2351 Value result;
2352 if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
2353 result = ROCDL::CvtSrBf8F32Op::create(rewriter, loc, i32, source, stoch,
2354 existing, op.getStoreIndex());
2355 else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
2356 result = ROCDL::CvtSrFp8F32Op::create(rewriter, loc, i32, source, stoch,
2357 existing, op.getStoreIndex());
2358
2359 result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
2360 op, getTypeConverter()->convertType(resultType), result);
2361 return success();
2362}
2363
2364// Implement the AMDGPU_DPPLowering class that will convert the amdgpu.dpp
2365// operation into the corresponding ROCDL instructions.
2366struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
2367 AMDGPUDPPLowering(const LLVMTypeConverter &converter, Chipset chipset)
2368 : ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
2369 Chipset chipset;
2370
2371 LogicalResult
2372 matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
2373 ConversionPatternRewriter &rewriter) const override {
2374
2375 // Convert the source operand to the corresponding LLVM type
2376 Location loc = DppOp.getLoc();
2377 Value src = adaptor.getSrc();
2378 Value old = adaptor.getOld();
2379 Type srcType = src.getType();
2380 Type oldType = old.getType();
2381 Type llvmType = nullptr;
2382 if (srcType.getIntOrFloatBitWidth() < 32) {
2383 llvmType = rewriter.getI32Type();
2384 } else if (isa<FloatType>(srcType)) {
2385 llvmType = (srcType.getIntOrFloatBitWidth() == 32)
2386 ? rewriter.getF32Type()
2387 : rewriter.getF64Type();
2388 } else if (isa<IntegerType>(srcType)) {
2389 llvmType = (srcType.getIntOrFloatBitWidth() == 32)
2390 ? rewriter.getI32Type()
2391 : rewriter.getI64Type();
2392 }
2393 auto llvmSrcIntType = typeConverter->convertType(
2394 rewriter.getIntegerType(srcType.getIntOrFloatBitWidth()));
2395
2396 // If the source type is narrower than 32 bits, widen it to an i32-sized value.
2397 auto convertOperand = [&](Value operand, Type operandType) {
2398 if (operandType.getIntOrFloatBitWidth() <= 16) {
2399 if (llvm::isa<FloatType>(operandType)) {
2400 operand =
2401 LLVM::BitcastOp::create(rewriter, loc, llvmSrcIntType, operand);
2402 }
2403 auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
2404 32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
2405 Value undefVec = LLVM::UndefOp::create(rewriter, loc, llvmVecType);
2406 operand =
2407 LLVM::InsertElementOp::create(rewriter, loc, undefVec, operand,
2408 createI32Constant(rewriter, loc, 0));
2409 operand = LLVM::BitcastOp::create(rewriter, loc, llvmType, operand);
2410 }
2411 return operand;
2412 };
2413
2414 src = convertOperand(src, srcType);
2415 old = convertOperand(old, oldType);
2416
2417 // This is taken from the following file llvm/lib/Target/AMDGPU/SIDefines.h
2418 enum DppCtrl : unsigned {
2419 ROW_SHL0 = 0x100,
2420 ROW_SHR0 = 0x110,
2421 ROW_ROR0 = 0x120,
2422 WAVE_SHL1 = 0x130,
2423 WAVE_ROL1 = 0x134,
2424 WAVE_SHR1 = 0x138,
2425 WAVE_ROR1 = 0x13C,
2426 ROW_MIRROR = 0x140,
2427 ROW_HALF_MIRROR = 0x141,
2428 BCAST15 = 0x142,
2429 BCAST31 = 0x143,
2430 };
2431
2432 auto kind = DppOp.getKind();
2433 auto permArgument = DppOp.getPermArgument();
2434 uint32_t DppCtrl = 0;
2435
2436 switch (kind) {
2437
2438 case DPPPerm::quad_perm: {
2439 auto quadPermAttr = cast<ArrayAttr>(*permArgument);
2440 int32_t i = 0;
2441 for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
2442 uint32_t num = elem.getInt();
2443 DppCtrl |= num << (i * 2);
2444 i++;
2445 }
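// Illustrative example (not from the original source): quad_perm = [3, 2, 1,
// 0] packs to 0b00011011 (0x1B), i.e. a full reversal within each quad.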
2446 break;
2447 }
2448 case DPPPerm::row_shl: {
2449 auto intAttr = cast<IntegerAttr>(*permArgument);
2450 DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
2451 break;
2452 }
2453 case DPPPerm::row_shr: {
2454 auto intAttr = cast<IntegerAttr>(*permArgument);
2455 DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
2456 break;
2457 }
2458 case DPPPerm::row_ror: {
2459 auto intAttr = cast<IntegerAttr>(*permArgument);
2460 DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
2461 break;
2462 }
2463 case DPPPerm::wave_shl:
2464 DppCtrl = DppCtrl::WAVE_SHL1;
2465 break;
2466 case DPPPerm::wave_shr:
2467 DppCtrl = DppCtrl::WAVE_SHR1;
2468 break;
2469 case DPPPerm::wave_rol:
2470 DppCtrl = DppCtrl::WAVE_ROL1;
2471 break;
2472 case DPPPerm::wave_ror:
2473 DppCtrl = DppCtrl::WAVE_ROR1;
2474 break;
2475 case DPPPerm::row_mirror:
2476 DppCtrl = DppCtrl::ROW_MIRROR;
2477 break;
2478 case DPPPerm::row_half_mirror:
2479 DppCtrl = DppCtrl::ROW_HALF_MIRROR;
2480 break;
2481 case DPPPerm::row_bcast_15:
2482 DppCtrl = DppCtrl::BCAST15;
2483 break;
2484 case DPPPerm::row_bcast_31:
2485 DppCtrl = DppCtrl::BCAST31;
2486 break;
2487 }
2488
2489 // Check for row_mask, bank_mask, bound_ctrl if they exist and create
2490 // constants
2491 auto rowMask = DppOp->getAttrOfType<IntegerAttr>("row_mask").getInt();
2492 auto bankMask = DppOp->getAttrOfType<IntegerAttr>("bank_mask").getInt();
2493 bool boundCtrl = DppOp->getAttrOfType<BoolAttr>("bound_ctrl").getValue();
2494
2495 // Create a ROCDL::DPPUpdateOp (DPP mov) with the appropriate attributes.
2496 auto dppMovOp =
2497 ROCDL::DPPUpdateOp::create(rewriter, loc, llvmType, old, src, DppCtrl,
2498 rowMask, bankMask, boundCtrl);
2499
2500 Value result = dppMovOp.getRes();
2501 if (srcType.getIntOrFloatBitWidth() < 32) {
2502 result = LLVM::TruncOp::create(rewriter, loc, llvmSrcIntType, result);
2503 if (!llvm::isa<IntegerType>(srcType)) {
2504 result = LLVM::BitcastOp::create(rewriter, loc, srcType, result);
2505 }
2506 }
2507
2508 // Replace the amdgpu.dpp operation with the result of the new
2509 // ROCDL::DPPUpdateOp.
2510 rewriter.replaceOp(DppOp, ValueRange(result));
2511 return success();
2512 }
2513};
2514
2515struct AMDGPUSwizzleBitModeLowering
2516 : public ConvertOpToLLVMPattern<SwizzleBitModeOp> {
2517 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
2518
2519 LogicalResult
2520 matchAndRewrite(SwizzleBitModeOp op, OpAdaptor adaptor,
2521 ConversionPatternRewriter &rewriter) const override {
2522 Location loc = op.getLoc();
2523 Type i32 = rewriter.getI32Type();
2524 Value src = adaptor.getSrc();
2525 SmallVector<Value> decomposed =
2526 LLVM::decomposeValue(rewriter, loc, src, i32);
2527 unsigned andMask = op.getAndMask();
2528 unsigned orMask = op.getOrMask();
2529 unsigned xorMask = op.getXorMask();
2530
2531 // bit 15 is 0 for the BitMode swizzle.
2532 // https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/
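// Illustrative example (not from the original source): andMask = 0x1f,
// orMask = 0, xorMask = 1 packs to 0x1f | (1 << 10) = 0x41f.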
2533 unsigned mask = andMask | (orMask << 5) | (xorMask << 10);
2534 Value maskValue = createI32Constant(rewriter, loc, mask);
2535 SmallVector<Value> swizzled;
2536 for (Value v : decomposed) {
2537 Value res =
2538 ROCDL::DsSwizzleOp::create(rewriter, loc, v.getType(), v, maskValue);
2539 swizzled.emplace_back(res);
2540 }
2541
2542 Value result = LLVM::composeValue(rewriter, loc, swizzled, src.getType());
2543 rewriter.replaceOp(op, result);
2544 return success();
2545 }
2546};
2547
2548struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneSwapOp> {
2550
2551 AMDGPUPermlaneLowering(const LLVMTypeConverter &converter, Chipset chipset)
2552 : ConvertOpToLLVMPattern<PermlaneSwapOp>(converter), chipset(chipset) {}
2553 Chipset chipset;
2554
2555 LogicalResult
2556 matchAndRewrite(PermlaneSwapOp op, OpAdaptor adaptor,
2557 ConversionPatternRewriter &rewriter) const override {
2558 if (chipset < kGfx950)
2559 return op->emitOpError("permlane_swap is only supported on gfx950+");
2560
2561 Location loc = op.getLoc();
2562 Type i32 = rewriter.getI32Type();
2563 Value src = adaptor.getSrc();
2564 unsigned rowLength = op.getRowLength();
2565 bool fi = op.getFetchInactive();
2566 bool boundctrl = op.getBoundCtrl();
2567
2568 SmallVector<Value> decomposed =
2569 LLVM::decomposeValue(rewriter, loc, src, i32);
2570
2571 SmallVector<Value> permuted;
2572 for (Value v : decomposed) {
2573 Value res;
2574 Type i32pair = LLVM::LLVMStructType::getLiteral(
2575 rewriter.getContext(), {v.getType(), v.getType()});
2576
2577 if (rowLength == 16)
2578 res = ROCDL::Permlane16SwapOp::create(rewriter, loc, i32pair, v, v, fi,
2579 boundctrl);
2580 else if (rowLength == 32)
2581 res = ROCDL::Permlane32SwapOp::create(rewriter, loc, i32pair, v, v, fi,
2582 boundctrl);
2583 else
2584 llvm_unreachable("unsupported row length");
2585
2586 Value vdst0 = LLVM::ExtractValueOp::create(rewriter, loc, res, {0});
2587 Value vdst1 = LLVM::ExtractValueOp::create(rewriter, loc, res, {1});
2588
2589 Value isEqual = LLVM::ICmpOp::create(rewriter, loc,
2590 LLVM::ICmpPredicate::eq, vdst0, v);
2591
2592 // Per `permlane(16|32)` semantics: if the first extracted element equals
2593 // 'v', the result is the second element; otherwise it is the first.
2594 Value vdstNew =
2595 LLVM::SelectOp::create(rewriter, loc, isEqual, vdst1, vdst0);
2596 permuted.emplace_back(vdstNew);
2597 }
2598
2599 Value result = LLVM::composeValue(rewriter, loc, permuted, src.getType());
2600 rewriter.replaceOp(op, result);
2601 return success();
2602 }
2603};
2604
2605static Value setValueAtOffset(ConversionPatternRewriter &rewriter, Location loc,
2606 Value accumulator, Value value, int64_t shift) {
2607 shift = shift % 32;
2608 Value shiftAmount;
2609 if (shift != 0) {
2610 shiftAmount = createI32Constant(rewriter, loc, shift % 32);
2611 value = LLVM::ShlOp::create(rewriter, loc, value, shiftAmount);
2612 }
2613
2614 if (matchPattern(accumulator, mlir::m_Zero()))
2615 return value;
2616
2617 constexpr bool isDisjoint = true;
2618 return LLVM::OrOp::create(rewriter, loc, accumulator, value, isDisjoint);
2619}
2620
2621template <typename BaseOp>
2622struct AMDGPUMakeDmaBaseLowering : public ConvertOpToLLVMPattern<BaseOp> {
2623 using ConvertOpToLLVMPattern<BaseOp>::ConvertOpToLLVMPattern;
2624 using Adaptor = typename ConvertOpToLLVMPattern<BaseOp>::OpAdaptor;
2625
2626 AMDGPUMakeDmaBaseLowering(const LLVMTypeConverter &converter, Chipset chipset)
2627 : ConvertOpToLLVMPattern<BaseOp>(converter), chipset(chipset) {}
2628 Chipset chipset;
2629
2630 LogicalResult
2631 matchAndRewrite(BaseOp op, Adaptor adaptor,
2632 ConversionPatternRewriter &rewriter) const override {
2633 if (chipset < kGfx1250)
2634 return op->emitOpError("make_dma_base is only supported on gfx1250");
2635
2636 Location loc = op.getLoc();
2637
2638 constexpr int32_t constlen = 4;
2639 Value consts[constlen];
2640 for (int64_t i = 0; i < constlen; ++i)
2641 consts[i] = createI32Constant(rewriter, loc, i);
2642
2643 constexpr int32_t sgprslen = constlen;
2644 Value sgprs[sgprslen];
2645 for (int64_t i = 0; i < sgprslen; ++i) {
2646 sgprs[i] = consts[0];
2647 }
2648
2649 sgprs[0] = consts[1];
2650
2651 if constexpr (BaseOp::isGather()) {
2652 sgprs[0] = setValueAtOffset(rewriter, loc, sgprs[0], consts[1], 30);
2653
2654 auto type = cast<TDMGatherBaseType>(op.getResult().getType());
2655 Type indexType = type.getIndexType();
2656 unsigned indexSize = indexType.getIntOrFloatBitWidth();
2657 assert(llvm::is_contained({16u, 32u}, indexSize) &&
2658 "expected index_size to be 16 or 32");
2659 unsigned idx = (indexSize / 16) - 1;
2660
2661 if (idx)
2662 sgprs[0] = setValueAtOffset(rewriter, loc, sgprs[0], consts[1], 31);
2663 }
2664
2665 ValueRange ldsIndices = adaptor.getLdsIndices();
2666 Value lds = adaptor.getLds();
2667 auto ldsMemRefType = cast<MemRefType>(op.getLds().getType());
2668
2669 Value ldsPtr = ConvertToLLVMPattern::getStridedElementPtr(
2670 rewriter, loc, ldsMemRefType, lds, ldsIndices);
2671
2672 ValueRange globalIndices = adaptor.getGlobalIndices();
2673 Value global = adaptor.getGlobal();
2674 auto globalMemRefType = cast<MemRefType>(op.getGlobal().getType());
2675
2676 Value globalPtr = ConvertToLLVMPattern::getStridedElementPtr(
2677 rewriter, loc, globalMemRefType, global, globalIndices);
2678
2679 Type i32 = rewriter.getI32Type();
2680 Type i64 = rewriter.getI64Type();
2681
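// Split the 64-bit global address: the low 32 bits go into sgprs[2], and the
// low 25 bits of the high half go into sgprs[3] together with the constant 2
// placed at bit offset 30.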
2682 sgprs[1] = LLVM::PtrToIntOp::create(rewriter, loc, i32, ldsPtr);
2683 Value castForGlobalAddr =
2684 LLVM::PtrToIntOp::create(rewriter, loc, i64, globalPtr);
2685
2686 sgprs[2] = LLVM::TruncOp::create(rewriter, loc, i32, castForGlobalAddr);
2687
2688 Value shift = LLVM::LShrOp::create(rewriter, loc, castForGlobalAddr,
2689 createI64Constant(rewriter, loc, 32));
2690
2691 Value highHalf = LLVM::TruncOp::create(rewriter, loc, i32, shift);
2692
2693 Value mask = createI32Constant(rewriter, loc, (1ull << 25) - 1);
2694 highHalf = LLVM::AndOp::create(rewriter, loc, highHalf, mask);
2695
2696 sgprs[3] = setValueAtOffset(rewriter, loc, highHalf, consts[2], 30);
2697
2698 Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
2699 assert(v4i32 && "expected type conversion to succeed");
2700 Value result = LLVM::PoisonOp::create(rewriter, loc, v4i32);
2701
2702 for (auto [sgpr, constant] : llvm::zip_equal(sgprs, consts))
2703 result =
2704 LLVM::InsertElementOp::create(rewriter, loc, result, sgpr, constant);
2705
2706 rewriter.replaceOp(op, result);
2707 return success();
2708 }
2709};
2710
2711template <typename DescriptorOp>
2712struct AMDGPULowerDescriptor : public ConvertOpToLLVMPattern<DescriptorOp> {
2713 using ConvertOpToLLVMPattern<DescriptorOp>::ConvertOpToLLVMPattern;
2714 using OpAdaptor = typename ConvertOpToLLVMPattern<DescriptorOp>::OpAdaptor;
2715
2716 AMDGPULowerDescriptor(const LLVMTypeConverter &converter, Chipset chipset)
2717 : ConvertOpToLLVMPattern<DescriptorOp>(converter), chipset(chipset) {}
2718 Chipset chipset;
2719
2720 Value getDGroup0(OpAdaptor adaptor) const { return adaptor.getBase(); }
2721
2722 Value setWorkgroupMask(DescriptorOp op, OpAdaptor adaptor,
2723 ConversionPatternRewriter &rewriter, Location loc,
2724 Value sgpr0) const {
2725 Value mask = op.getWorkgroupMask();
2726 if (!mask)
2727 return sgpr0;
2728
2729 Type i16 = rewriter.getI16Type();
2730 mask = LLVM::BitcastOp::create(rewriter, loc, i16, mask);
2731 Type i32 = rewriter.getI32Type();
2732 Value extendedMask = LLVM::ZExtOp::create(rewriter, loc, i32, mask);
2733 return setValueAtOffset(rewriter, loc, sgpr0, extendedMask, 0);
2734 }
2735
2736 Value setDataSize(DescriptorOp op, OpAdaptor adaptor,
2737 ConversionPatternRewriter &rewriter, Location loc,
2738 Value sgpr0, ArrayRef<Value> consts) const {
2739 unsigned elementTypeWidthInBits = op.getElementTypeWidth();
2740 assert(llvm::is_contained({8u, 16u, 32u, 64u}, elementTypeWidthInBits) &&
2741 "expected type width to be 8, 16, 32, or 64.");
2742 int64_t idx = llvm::Log2_32(elementTypeWidthInBits / 8);
2743 Value size = consts[idx];
2744 return setValueAtOffset(rewriter, loc, sgpr0, size, 16);
2745 }
2746
2747 Value setAtomicBarrier(DescriptorOp op, OpAdaptor adaptor,
2748 ConversionPatternRewriter &rewriter, Location loc,
2749 Value sgpr0, ArrayRef<Value> consts) const {
2750 if (!adaptor.getAtomicBarrierAddress())
2751 return sgpr0;
2752
2753 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 18);
2754 }
2755
2756 Value setIterateEnable(DescriptorOp op, OpAdaptor adaptor,
2757 ConversionPatternRewriter &rewriter, Location loc,
2758 Value sgpr0, ArrayRef<Value> consts) const {
2759 if (!adaptor.getGlobalIncrement())
2760 return sgpr0;
2761
2762 // Value is ignored when in gather mode.
2763 // TODO: emit error earlier?
2764 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 19);
2765 }
2766
2767 Value setPadEnable(DescriptorOp op, OpAdaptor adaptor,
2768 ConversionPatternRewriter &rewriter, Location loc,
2769 Value sgpr0, ArrayRef<Value> consts) const {
2770 if (!op.getPadAmount())
2771 return sgpr0;
2772
2773 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 20);
2774 }
2775
2776 Value setEarlyTimeout(DescriptorOp op, OpAdaptor adaptor,
2777 ConversionPatternRewriter &rewriter, Location loc,
2778 Value sgpr0, ArrayRef<Value> consts) const {
2779 if (!op.getWorkgroupMask())
2780 return sgpr0;
2781
2782 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 21);
2783 }
2784
2785 Value setPadInterval(DescriptorOp op, OpAdaptor adaptor,
2786 ConversionPatternRewriter &rewriter, Location loc,
2787 Value sgpr0, ArrayRef<Value> consts) const {
2788 if (!op.getPadAmount())
2789 return sgpr0;
2790
2791 // pre-condition: padInterval must be a power of two between 2 and 256.
2792 // TODO: validate that the value satisfies the pre-condition.
2793 // If the pre-condition is violated, higher bits of the word may be
2794 // clobbered. In a follow-up PR, implement
2795 // RuntimeVerifiableOpInterface to instrument conditions that need to be
2796 // checked at runtime.
2797 IntegerType i32 = rewriter.getI32Type();
2798 Value padInterval = adaptor.getPadInterval();
2799 padInterval = LLVM::CountTrailingZerosOp::create(rewriter, loc, i32,
2800 padInterval, false);
2801 padInterval = LLVM::SubOp::create(rewriter, loc, padInterval, consts[1]);
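// e.g. padInterval = 8 gives cttz(8) = 3, which is encoded as 3 - 1 = 2.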
2802 // post-condition: padInterval is now a value between 0 and 7.
2803 return setValueAtOffset(rewriter, loc, sgpr0, padInterval, 22);
2804 }
2805
2806 Value setPadAmount(DescriptorOp op, OpAdaptor adaptor,
2807 ConversionPatternRewriter &rewriter, Location loc,
2808 Value sgpr0, ArrayRef<Value> consts) const {
2809 if (!op.getPadAmount())
2810 return sgpr0;
2811
2812 // pre-condition: padAmount is a value in [1, 128].
2813 // TODO: validate that the value satisfies the pre-condition.
2814 // If the pre-condition is violated, higher bits of the word may be
2815 // clobbered. In a follow-up PR, implement
2816 // RuntimeVerifiableOpInterface to instrument conditions that need to be
2817 // checked at runtime.
2818 Value padAmount = adaptor.getPadAmount();
2819 padAmount = LLVM::SubOp::create(rewriter, loc, padAmount, consts[1]);
2820 // post-condition: padAmount is now a value in [0, 127].
2821 return setValueAtOffset(rewriter, loc, sgpr0, padAmount, 25);
2822 }
2823
2824 Value setAtomicBarrierAddress(DescriptorOp op, OpAdaptor adaptor,
2825 ConversionPatternRewriter &rewriter,
2826 Location loc, Value sgpr1,
2827 ArrayRef<Value> consts) const {
2828 if (!adaptor.getAtomicBarrierAddress())
2829 return sgpr1;
2830
2831 Value atomicBarrierAddress = adaptor.getAtomicBarrierAddress();
2832 auto barrierAddressTy =
2833 cast<MemRefType>(op.getAtomicBarrierAddress().getType());
2834 ValueRange atomicBarrierIndices = adaptor.getAtomicBarrierIndices();
2835 atomicBarrierAddress = ConvertToLLVMPattern::getStridedElementPtr(
2836 rewriter, loc, barrierAddressTy, atomicBarrierAddress,
2837 atomicBarrierIndices);
2838 IntegerType i32 = rewriter.getI32Type();
2839 // pre-condition: atomicBarrierAddress is aligned to 8 bytes, which implies
2840 // that the 3 LSBs are zero.
2841 // TODO: validate that the value satisfies the pre-condition.
2842 // In a follow-up PR, implement RuntimeVerifiableOpInterface
2843 // to instrument conditions that need to be checked at runtime.
2844 atomicBarrierAddress =
2845 LLVM::PtrToIntOp::create(rewriter, loc, i32, atomicBarrierAddress);
2846 atomicBarrierAddress =
2847 LLVM::LShrOp::create(rewriter, loc, atomicBarrierAddress, consts[3]);
2848 Value mask = createI32Constant(rewriter, loc, 0xFFFF);
2849 atomicBarrierAddress =
2850 LLVM::AndOp::create(rewriter, loc, atomicBarrierAddress, mask);
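// The 8-byte alignment guarantees the 3 LSBs are zero, so they are shifted out and only the low 16 bits of the shifted address are stored.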
2851 return setValueAtOffset(rewriter, loc, sgpr1, atomicBarrierAddress, 32);
2852 }
2853
2854 std::pair<Value, Value> setTensorDimX(DescriptorOp op, OpAdaptor adaptor,
2855 ConversionPatternRewriter &rewriter,
2856 Location loc, Value sgpr1, Value sgpr2,
2857 ArrayRef<Value> consts, uint64_t dimX,
2858 uint32_t offset) const {
2859 ArrayRef<int64_t> globalStaticSizes = adaptor.getGlobalStaticSizes();
2860 ValueRange globalDynamicSizes = adaptor.getGlobalDynamicSizes();
2861 SmallVector<OpFoldResult> mixedGlobalSizes =
2862 getMixedValues(globalStaticSizes, globalDynamicSizes, rewriter);
2863 if (mixedGlobalSizes.size() <= dimX)
2864 return {sgpr1, sgpr2};
2865
2866 OpFoldResult tensorDimXOpFoldResult = *(mixedGlobalSizes.rbegin() + dimX);
2867 // pre-condition: tensorDimX is less than 2^32 - 1.
2868 // TODO: validate that the value satisfies the pre-condition.
2869 // In a follow-up PR, implement RuntimeVerifiableOpInterface to instrument
2870 // conditions that need to be checked at runtime. This could also be fixed
2871 // by declaring mixedGlobalSizes to be a DynamicI32List.
2872 Value tensorDimX;
2873 if (auto attr = dyn_cast<Attribute>(tensorDimXOpFoldResult)) {
2874 tensorDimX =
2875 createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
2876 } else {
2877 IntegerType i32 = rewriter.getI32Type();
2878 tensorDimX = cast<Value>(tensorDimXOpFoldResult);
2879 tensorDimX = LLVM::TruncOp::create(rewriter, loc, i32, tensorDimX);
2880 }
2881
2882 sgpr1 = setValueAtOffset(rewriter, loc, sgpr1, tensorDimX, offset);
2883
2884 Value c16 = createI32Constant(rewriter, loc, 16);
2885 Value tensorDimXHigh = LLVM::LShrOp::create(rewriter, loc, tensorDimX, c16);
2886 sgpr2 = setValueAtOffset(rewriter, loc, sgpr2, tensorDimXHigh, offset + 16);
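// The dimension field straddles a register boundary: 16 bits land in the first register at `offset`, and the bits shifted out continue at the start of the next register.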
2887 return {sgpr1, sgpr2};
2888 }
2889
2890 std::pair<Value, Value> setTensorDim0(DescriptorOp op, OpAdaptor adaptor,
2891 ConversionPatternRewriter &rewriter,
2892 Location loc, Value sgpr1, Value sgpr2,
2893 ArrayRef<Value> consts) const {
2894 return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, sgpr2, consts, 0,
2895 48);
2896 }
2897
2898 std::pair<Value, Value> setTensorDim1(DescriptorOp op, OpAdaptor adaptor,
2899 ConversionPatternRewriter &rewriter,
2900 Location loc, Value sgpr2, Value sgpr3,
2901 ArrayRef<Value> consts) const {
2902 return setTensorDimX(op, adaptor, rewriter, loc, sgpr2, sgpr3, consts, 1,
2903 80);
2904 }
2905
2906 Value setTileDimX(DescriptorOp op, OpAdaptor adaptor,
2907 ConversionPatternRewriter &rewriter, Location loc,
2908 Value sgpr, ArrayRef<Value> consts, size_t dimX,
2909 int64_t offset) const {
2910 ArrayRef<int64_t> sharedStaticSizes = adaptor.getSharedStaticSizes();
2911 ValueRange sharedDynamicSizes = adaptor.getSharedDynamicSizes();
2912 SmallVector<OpFoldResult> mixedSharedSizes =
2913 getMixedValues(sharedStaticSizes, sharedDynamicSizes, rewriter);
2914 if (mixedSharedSizes.size() <= dimX)
2915 return sgpr;
2916
2917 OpFoldResult tileDimXOpFoldResult = *(mixedSharedSizes.rbegin() + dimX);
2918 // pre-condition: tileDimX is less than 2^16 - 1.
2919 // TODO: validate that the value satisfies the pre-condition.
2920 // If the pre-condition is violated, higher bits of the word may be
2921 // clobbered. In a follow-up PR, implement
2922 // RuntimeVerifiableOpInterface to instrument conditions that need to be
2923 // checked at runtime. This could also be fixed by declaring
2924 // mixedSharedSizes to be a DynamicI16List.
2925 Value tileDimX;
2926 if (auto attr = dyn_cast<Attribute>(tileDimXOpFoldResult)) {
2927 tileDimX =
2928 createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
2929 } else {
2930 IntegerType i32 = rewriter.getI32Type();
2931 tileDimX = cast<Value>(tileDimXOpFoldResult);
2932 tileDimX = LLVM::TruncOp::create(rewriter, loc, i32, tileDimX);
2933 }
2934
2935 return setValueAtOffset(rewriter, loc, sgpr, tileDimX, offset);
2936 }
2937
2938 Value setTileDim0(DescriptorOp op, OpAdaptor adaptor,
2939 ConversionPatternRewriter &rewriter, Location loc,
2940 Value sgpr3, ArrayRef<Value> consts) const {
2941 return setTileDimX(op, adaptor, rewriter, loc, sgpr3, consts, 0, 112);
2942 }
2943
2944 Value setTileDim1(DescriptorOp op, OpAdaptor adaptor,
2945 ConversionPatternRewriter &rewriter, Location loc,
2946 Value sgpr4, ArrayRef<Value> consts) const {
2947 return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 1, 128);
2948 }
2949
2950 Value setValidIndices(DescriptorOp op, OpAdaptor adaptor,
2951 ConversionPatternRewriter &rewriter, Location loc,
2952 Value sgpr4, ArrayRef<Value> consts) const {
2953 auto type = cast<VectorType>(op.getIndices().getType());
2954 ArrayRef<int64_t> shape = type.getShape();
2955 assert(shape.size() == 1 && "expected shape to be of rank 1.");
2956 unsigned length = shape.back();
2957 assert(0 < length && length <= 16 && "expected length to be between 1 and 16.");
2958 Value value = createI32Constant(rewriter, loc, length);
2959 return setValueAtOffset(rewriter, loc, sgpr4, value, 128);
2960 }
2961
2962 Value setTileDim1OrValidIndices(DescriptorOp op, OpAdaptor adaptor,
2963 ConversionPatternRewriter &rewriter,
2964 Location loc, Value sgpr4,
2965 ArrayRef<Value> consts) const {
2966 if constexpr (DescriptorOp::isGather())
2967 return setValidIndices(op, adaptor, rewriter, loc, sgpr4, consts);
2968 return setTileDim1(op, adaptor, rewriter, loc, sgpr4, consts);
2969 }
2970
2971 Value setTileDim2(DescriptorOp op, OpAdaptor adaptor,
2972 ConversionPatternRewriter &rewriter, Location loc,
2973 Value sgpr4, ArrayRef<Value> consts) const {
2974 // Value is ignored when in gather mode.
2975 if constexpr (DescriptorOp::isGather())
2976 return sgpr4;
2977 return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 2, 144);
2978 }
2979
2980 std::pair<Value, Value>
2981 setTensorDimXStride(DescriptorOp op, OpAdaptor adaptor,
2982 ConversionPatternRewriter &rewriter, Location loc,
2983 Value sgprY, Value sgprZ, ArrayRef<Value> consts,
2984 size_t dimX, int64_t offset) const {
2985 ArrayRef<int64_t> globalStaticStrides = adaptor.getGlobalStaticStrides();
2986 ValueRange globalDynamicStrides = adaptor.getGlobalDynamicStrides();
2987 SmallVector<OpFoldResult> mixedGlobalStrides =
2988 getMixedValues(globalStaticStrides, globalDynamicStrides, rewriter);
2989
2990 if (mixedGlobalStrides.size() <= (dimX + 1))
2991 return {sgprY, sgprZ};
2992
2993 OpFoldResult tensorDimXStrideOpFoldResult =
2994 *(mixedGlobalStrides.rbegin() + dimX + 1);
2995 // pre-condition: tensorDimXStride is less than 2^48 - 1.
2996 // TODO: validate that the value satisfies the pre-condition.
2997 // In a follow-up PR, implement RuntimeVerifiableOpInterface to instrument
2998 // conditions that need to be checked at runtime.
2999 Value tensorDimXStride;
3000 if (auto attr = dyn_cast<Attribute>(tensorDimXStrideOpFoldResult))
3001 tensorDimXStride =
3002 createI64Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
3003 else
3004 tensorDimXStride = cast<Value>(tensorDimXStrideOpFoldResult);
3005
3006 constexpr int64_t first48bits = (1ll << 48) - 1;
3007 Value mask = createI64Constant(rewriter, loc, first48bits);
3008 tensorDimXStride =
3009 LLVM::AndOp::create(rewriter, loc, mask, tensorDimXStride);
3010 IntegerType i32 = rewriter.getI32Type();
3011 Value tensorDimXStrideLow =
3012 LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStride);
3013 sgprY = setValueAtOffset(rewriter, loc, sgprY, tensorDimXStrideLow, offset);
3014
3015 int64_t shift = (offset % 32) == 0 ? 32 : offset % 32;
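// `shift` is the number of the stride's low bits that fit in sgprY before the field crosses over into sgprZ.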
3016 Value shiftVal = createI64Constant(rewriter, loc, shift);
3017 Value tensorDimXStrideHigh =
3018 LLVM::LShrOp::create(rewriter, loc, tensorDimXStride, shiftVal);
3019 tensorDimXStrideHigh =
3020 LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStrideHigh);
3021 sgprZ = setValueAtOffset(rewriter, loc, sgprZ, tensorDimXStrideHigh,
3022 offset + shift);
3023 return {sgprY, sgprZ};
3024 }
3025
3026 std::pair<Value, Value>
3027 setTensorDim0Stride(DescriptorOp op, OpAdaptor adaptor,
3028 ConversionPatternRewriter &rewriter, Location loc,
3029 Value sgpr5, Value sgpr6, ArrayRef<Value> consts) const {
3030 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts,
3031 0, 160);
3032 }
3033
3034 std::pair<Value, Value>
3035 setTensorDim1Stride(DescriptorOp op, OpAdaptor adaptor,
3036 ConversionPatternRewriter &rewriter, Location loc,
3037 Value sgpr5, Value sgpr6, ArrayRef<Value> consts) const {
3038 // Value is ignored when in gather mode.
3039 if constexpr (DescriptorOp::isGather())
3040 return {sgpr5, sgpr6};
3041 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts,
3042 1, 208);
3043 }
3044
3045 Value getDGroup1(DescriptorOp op, OpAdaptor adaptor,
3046 ConversionPatternRewriter &rewriter, Location loc,
3047 ArrayRef<Value> consts) const {
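// Assemble descriptor group 1: eight 32-bit words are built up field by field and then packed into a v8i32 vector.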
3048 Value sgprs[8];
3049 for (int64_t i = 0; i < 8; ++i) {
3050 sgprs[i] = consts[0];
3051 }
3052
3053 sgprs[0] = setWorkgroupMask(op, adaptor, rewriter, loc, sgprs[0]);
3054 sgprs[0] = setDataSize(op, adaptor, rewriter, loc, sgprs[0], consts);
3055 sgprs[0] = setAtomicBarrier(op, adaptor, rewriter, loc, sgprs[0], consts);
3056 sgprs[0] = setIterateEnable(op, adaptor, rewriter, loc, sgprs[0], consts);
3057 sgprs[0] = setPadEnable(op, adaptor, rewriter, loc, sgprs[0], consts);
3058 sgprs[0] = setEarlyTimeout(op, adaptor, rewriter, loc, sgprs[0], consts);
3059 sgprs[0] = setPadInterval(op, adaptor, rewriter, loc, sgprs[0], consts);
3060 sgprs[0] = setPadAmount(op, adaptor, rewriter, loc, sgprs[0], consts);
3061
3062 sgprs[1] =
3063 setAtomicBarrierAddress(op, adaptor, rewriter, loc, sgprs[1], consts);
3064 std::tie(sgprs[1], sgprs[2]) =
3065 setTensorDim0(op, adaptor, rewriter, loc, sgprs[1], sgprs[2], consts);
3066 std::tie(sgprs[2], sgprs[3]) =
3067 setTensorDim1(op, adaptor, rewriter, loc, sgprs[2], sgprs[3], consts);
3068
3069 sgprs[3] = setTileDim0(op, adaptor, rewriter, loc, sgprs[3], consts);
3070 sgprs[4] =
3071 setTileDim1OrValidIndices(op, adaptor, rewriter, loc, sgprs[4], consts);
3072 sgprs[4] = setTileDim2(op, adaptor, rewriter, loc, sgprs[4], consts);
3073 std::tie(sgprs[5], sgprs[6]) = setTensorDim0Stride(
3074 op, adaptor, rewriter, loc, sgprs[5], sgprs[6], consts);
3075 std::tie(sgprs[6], sgprs[7]) = setTensorDim1Stride(
3076 op, adaptor, rewriter, loc, sgprs[6], sgprs[7], consts);
3077
3078 IntegerType i32 = rewriter.getI32Type();
3079 Type v8i32 = this->typeConverter->convertType(VectorType::get(8, i32));
3080 assert(v8i32 && "expected type conversion to succeed");
3081 Value dgroup1 = LLVM::PoisonOp::create(rewriter, loc, v8i32);
3082
3083 for (auto [sgpr, constant] : llvm::zip_equal(sgprs, consts)) {
3084 dgroup1 =
3085 LLVM::InsertElementOp::create(rewriter, loc, dgroup1, sgpr, constant);
3086 }
3087
3088 return dgroup1;
3089 }
3090
3091 Value setTensorDimX(DescriptorOp op, OpAdaptor adaptor,
3092 ConversionPatternRewriter &rewriter, Location loc,
3093 Value sgpr0, ArrayRef<Value> consts, int64_t dimX,
3094 int64_t offset) const {
3095 ArrayRef<int64_t> globalStaticSizes = adaptor.getGlobalStaticSizes();
3096 ValueRange globalDynamicSizes = adaptor.getGlobalDynamicSizes();
3097 SmallVector<OpFoldResult> mixedGlobalSizes =
3098 getMixedValues(globalStaticSizes, globalDynamicSizes, rewriter);
3099 if (mixedGlobalSizes.size() <= static_cast<unsigned long>(dimX))
3100 return sgpr0;
3101
3102 OpFoldResult tensorDimXOpFoldResult = *(mixedGlobalSizes.rbegin() + dimX);
3103 Value tensorDimX;
3104 if (auto attr = dyn_cast<Attribute>(tensorDimXOpFoldResult)) {
3105 tensorDimX =
3106 createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
3107 } else {
3108 IntegerType i32 = rewriter.getI32Type();
3109 tensorDimX = cast<Value>(tensorDimXOpFoldResult);
3110 tensorDimX = LLVM::TruncOp::create(rewriter, loc, i32, tensorDimX);
3111 }
3112
3113 return setValueAtOffset(rewriter, loc, sgpr0, tensorDimX, offset);
3114 }
3115
3116 Value setTensorDim2(DescriptorOp op, OpAdaptor adaptor,
3117 ConversionPatternRewriter &rewriter, Location loc,
3118 Value sgpr0, ArrayRef<Value> consts) const {
3119 return setTensorDimX(op, adaptor, rewriter, loc, sgpr0, consts, 2, 0);
3120 }
3121
3122 Value truncateAndSetValueAtOffset(ConversionPatternRewriter &rewriter,
3123 Location loc, Value accumulator,
3124 Value value, int64_t shift) const {
3125
3126 IntegerType i32 = rewriter.getI32Type();
3127 value = LLVM::TruncOp::create(rewriter, loc, i32, value);
3128 return setValueAtOffset(rewriter, loc, accumulator, value, shift);
3129 }
3130
3131 Value setLDSAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
3132 ConversionPatternRewriter &rewriter, Location loc,
3133 Value sgpr1, ArrayRef<Value> consts,
3134 int64_t offset) const {
3135 Value ldsAddrIncrement = adaptor.getLdsIncrement();
3136 return setValueAtOffset(rewriter, loc, sgpr1, ldsAddrIncrement, offset);
3137 }
3138
3139 std::pair<Value, Value>
3140 setGlobalAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
3141 ConversionPatternRewriter &rewriter, Location loc,
3142 Value sgpr2, Value sgpr3, ArrayRef<Value> consts,
3143 int64_t offset) const {
3144 Value globalAddrIncrement = adaptor.getGlobalIncrement();
3145 sgpr2 = truncateAndSetValueAtOffset(rewriter, loc, sgpr2,
3146 globalAddrIncrement, offset);
3147 Value shift = createI64Constant(rewriter, loc, 32);
3148 globalAddrIncrement =
3149 LLVM::LShrOp::create(rewriter, loc, globalAddrIncrement, shift);
3150 constexpr int64_t first16BitsHigh = (1ll << 16) - 1;
3151 sgpr3 = truncateAndSetValueAtOffset(rewriter, loc, sgpr3,
3152 globalAddrIncrement, offset + 32);
3153 Value mask = createI32Constant(rewriter, loc, first16BitsHigh);
3154 sgpr3 = LLVM::AndOp::create(rewriter, loc, sgpr3, mask);
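// The 64-bit global increment is split: the low 32 bits go into sgpr2 and only 16 bits of the upper half are kept in sgpr3.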
3155 return {sgpr2, sgpr3};
3156 }
3157
3158 Value setTensorDim3OrLDSAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
3159 ConversionPatternRewriter &rewriter,
3160 Location loc, Value sgpr1,
3161 ArrayRef<Value> consts) const {
3162 Value ldsIncrement = op.getLdsIncrement();
3163 constexpr int64_t dim = 3;
3164 constexpr int64_t offset = 32;
3165 if (!ldsIncrement)
3166 return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, consts, dim,
3167 offset);
3168 return setLDSAddrIncrement(op, adaptor, rewriter, loc, sgpr1, consts,
3169 offset);
3170 }
3171
3172 std::pair<Value, Value> setTensorDim2StrideOrGlobalAddrIncrement(
3173 DescriptorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter,
3174 Location loc, Value sgpr2, Value sgpr3, ArrayRef<Value> consts) const {
3175 Value globalIncrement = op.getGlobalIncrement();
3176 constexpr int32_t dim = 2;
3177 constexpr int32_t offset = 64;
3178 if (!globalIncrement)
3179 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr2, sgpr3,
3180 consts, dim, offset);
3181 return setGlobalAddrIncrement(op, adaptor, rewriter, loc, sgpr2, sgpr3,
3182 consts, offset);
3183 }
3184
3185 Value setIterateCount(DescriptorOp op, OpAdaptor adaptor,
3186 ConversionPatternRewriter &rewriter, Location loc,
3187 Value sgpr3, ArrayRef<Value> consts,
3188 int32_t offset) const {
3189 Value iterationCount = adaptor.getIterationCount();
3190 IntegerType i32 = rewriter.getI32Type();
3191 // pre-condition: iterationCount is in the inclusive interval [1, 256].
3192 // TODO: validate that the value satisfies the pre-condition.
3193 // If the pre-condition is violated, higher bits of the word may be
3194 // clobbered. In a follow-up PR, implement
3195 // RuntimeVerifiableOpInterface to instrument conditions that need to be
3196 // checked at runtime.
3197 iterationCount = LLVM::TruncOp::create(rewriter, loc, i32, iterationCount);
3198 iterationCount =
3199 LLVM::SubOp::create(rewriter, loc, iterationCount, consts[1]);
3200 return setValueAtOffset(rewriter, loc, sgpr3, iterationCount, offset);
3201 }
3202
3203 Value setTileDim3OrIterateCount(DescriptorOp op, OpAdaptor adaptor,
3204 ConversionPatternRewriter &rewriter,
3205 Location loc, Value sgpr3,
3206 ArrayRef<Value> consts) const {
3207 Value iterateCount = op.getIterationCount();
3208 constexpr int32_t dim = 2;
3209 constexpr int32_t offset = 112;
3210 if (!iterateCount)
3211 return setTileDimX(op, adaptor, rewriter, loc, sgpr3, consts, dim,
3212 offset);
3213
3214 return setIterateCount(op, adaptor, rewriter, loc, sgpr3, consts, offset);
3215 }
3216
3217 Value getDGroup2(DescriptorOp op, OpAdaptor adaptor,
3218 ConversionPatternRewriter &rewriter, Location loc,
3219 ArrayRef<Value> consts) const {
3220 if constexpr (DescriptorOp::isGather())
3221 return getDGroup2Gather(op, adaptor, rewriter, loc, consts);
3222 return getDGroup2NonGather(op, adaptor, rewriter, loc, consts);
3223 }
3224
3225 Value getDGroup2NonGather(DescriptorOp op, OpAdaptor adaptor,
3226 ConversionPatternRewriter &rewriter, Location loc,
3227 ArrayRef<Value> consts) const {
3228 IntegerType i32 = rewriter.getI32Type();
3229 Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
3230 assert(v4i32 && "expected type conversion to succeed.");
3231
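// A rank <= 2 descriptor with no LDS increment fits entirely in dgroup0/dgroup1, so the remaining groups are all zeros.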
3232 bool onlyNeedsTwoDescriptors = !op.getLdsIncrement() && op.getRank() <= 2;
3233 if (onlyNeedsTwoDescriptors)
3234 return LLVM::ZeroOp::create(rewriter, loc, v4i32);
3235
3236 constexpr int64_t sgprlen = 4;
3237 Value sgprs[sgprlen];
3238 for (int i = 0; i < sgprlen; ++i)
3239 sgprs[i] = consts[0];
3240
3241 sgprs[0] = setTensorDim2(op, adaptor, rewriter, loc, sgprs[0], consts);
3242 sgprs[1] = setTensorDim3OrLDSAddrIncrement(op, adaptor, rewriter, loc,
3243 sgprs[1], consts);
3244 std::tie(sgprs[2], sgprs[3]) = setTensorDim2StrideOrGlobalAddrIncrement(
3245 op, adaptor, rewriter, loc, sgprs[2], sgprs[3], consts);
3246 sgprs[3] =
3247 setTileDim3OrIterateCount(op, adaptor, rewriter, loc, sgprs[3], consts);
3248
3249 Value dgroup2 = LLVM::PoisonOp::create(rewriter, loc, v4i32);
3250 for (auto [sgpr, constant] : llvm::zip(sgprs, consts))
3251 dgroup2 =
3252 LLVM::InsertElementOp::create(rewriter, loc, dgroup2, sgpr, constant);
3253
3254 return dgroup2;
3255 }
3256
3257 Value getGatherIndices(DescriptorOp op, OpAdaptor adaptor,
3258 ConversionPatternRewriter &rewriter, Location loc,
3259 ArrayRef<Value> consts, bool firstHalf) const {
3260 IntegerType i32 = rewriter.getI32Type();
3261 Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
3262 assert(v4i32 && "expected type conversion to succeed.");
3263
3264 Value indices = adaptor.getIndices();
3265 auto vectorType = cast<VectorType>(indices.getType());
3266 unsigned length = vectorType.getShape().back();
3267 Type elementType = vectorType.getElementType();
3268 unsigned maxLength = elementType == i32 ? 4 : 8;
3269 int32_t offset = firstHalf ? 0 : maxLength;
3270 unsigned discountedLength =
3271 std::max(static_cast<int32_t>(length - offset), 0);
3272
3273 unsigned targetSize = std::min(maxLength, discountedLength);
3274
3275 SmallVector<Value> indicesVector;
3276 for (unsigned i = offset; i < targetSize + offset; ++i) {
3277 Value idx;
3278 if (i < consts.size())
3279 idx = consts[i];
3280 else
3281 idx = createI32Constant(rewriter, loc, i);
3282 Value elem = LLVM::ExtractElementOp::create(rewriter, loc, indices, idx);
3283 indicesVector.push_back(elem);
3284 }
3285
3286 SmallVector<Value> indicesI32Vector;
3287 if (elementType == i32) {
3288 indicesI32Vector = indicesVector;
3289 } else {
3290 for (unsigned i = 0; i < targetSize; ++i) {
3291 Value index = indicesVector[i];
3292 indicesI32Vector.push_back(
3293 LLVM::ZExtOp::create(rewriter, loc, i32, index));
3294 }
3295 if ((targetSize % 2) != 0)
3296 // Add padding when not divisible by two.
3297 indicesI32Vector.push_back(consts[0]);
3298 }
3299
3300 SmallVector<Value> indicesToInsert;
3301 if (elementType == i32) {
3302 indicesToInsert = indicesI32Vector;
3303 } else {
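// Indices narrower than i32 are packed pairwise: each i32 word holds one index in bits [15:0] and the next in bits [31:16].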
3304 unsigned size = indicesI32Vector.size() / 2;
3305 for (unsigned i = 0; i < size; ++i) {
3306 Value first = indicesI32Vector[2 * i];
3307 Value second = indicesI32Vector[2 * i + 1];
3308 Value joined = setValueAtOffset(rewriter, loc, first, second, 16);
3309 indicesToInsert.push_back(joined);
3310 }
3311 }
3312
3313 Value dgroup = LLVM::PoisonOp::create(rewriter, loc, v4i32);
3314 for (auto [sgpr, constant] : llvm::zip_first(indicesToInsert, consts))
3315 dgroup =
3316 LLVM::InsertElementOp::create(rewriter, loc, dgroup, sgpr, constant);
3317
3318 return dgroup;
3319 }
3320
3321 Value getDGroup2Gather(DescriptorOp op, OpAdaptor adaptor,
3322 ConversionPatternRewriter &rewriter, Location loc,
3323 ArrayRef<Value> consts) const {
3324 return getGatherIndices(op, adaptor, rewriter, loc, consts, true);
3325 }
3326
3327 std::pair<Value, Value>
3328 setTensorDim3Stride(DescriptorOp op, OpAdaptor adaptor,
3329 ConversionPatternRewriter &rewriter, Location loc,
3330 Value sgpr0, Value sgpr1, ArrayRef<Value> consts) const {
3331 constexpr int32_t dim = 3;
3332 constexpr int32_t offset = 0;
3333 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr0, sgpr1, consts,
3334 dim, offset);
3335 }
3336
3337 std::pair<Value, Value> setTensorDim4(DescriptorOp op, OpAdaptor adaptor,
3338 ConversionPatternRewriter &rewriter,
3339 Location loc, Value sgpr1, Value sgpr2,
3340 ArrayRef<Value> consts) const {
3341 constexpr int32_t dim = 4;
3342 constexpr int32_t offset = 48;
3343 return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, sgpr2, consts, dim,
3344 offset);
3345 }
3346
3347 Value setTileDim4(DescriptorOp op, OpAdaptor adaptor,
3348 ConversionPatternRewriter &rewriter, Location loc,
3349 Value sgpr2, ArrayRef<Value> consts) const {
3350 constexpr int32_t dim = 4;
3351 constexpr int32_t offset = 80;
3352 return setTileDimX(op, adaptor, rewriter, loc, sgpr2, consts, dim, offset);
3353 }
3354
3355 Value getDGroup3(DescriptorOp op, OpAdaptor adaptor,
3356 ConversionPatternRewriter &rewriter, Location loc,
3357 ArrayRef<Value> consts) const {
3358 if constexpr (DescriptorOp::isGather())
3359 return getDGroup3Gather(op, adaptor, rewriter, loc, consts);
3360 return getDGroup3NonGather(op, adaptor, rewriter, loc, consts);
3361 }
3362
3363 Value getDGroup3NonGather(DescriptorOp op, OpAdaptor adaptor,
3364 ConversionPatternRewriter &rewriter, Location loc,
3365 ArrayRef<Value> consts) const {
3366 IntegerType i32 = rewriter.getI32Type();
3367 Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
3368 assert(v4i32 && "expected type conversion to succeed.");
3369 bool onlyNeedsTwoDescriptors = !op.getLdsIncrement() && op.getRank() <= 2;
3370 if (onlyNeedsTwoDescriptors)
3371 return LLVM::ZeroOp::create(rewriter, loc, v4i32);
3372
3373 constexpr int32_t sgprlen = 4;
3374 Value sgprs[sgprlen];
3375 for (int i = 0; i < sgprlen; ++i)
3376 sgprs[i] = consts[0];
3377
3378 std::tie(sgprs[0], sgprs[1]) = setTensorDim3Stride(
3379 op, adaptor, rewriter, loc, sgprs[0], sgprs[1], consts);
3380 std::tie(sgprs[1], sgprs[2]) =
3381 setTensorDim4(op, adaptor, rewriter, loc, sgprs[1], sgprs[2], consts);
3382 sgprs[2] = setTileDim4(op, adaptor, rewriter, loc, sgprs[2], consts);
3383
3384 Value dgroup3 = LLVM::PoisonOp::create(rewriter, loc, v4i32);
3385 for (auto [sgpr, constant] : llvm::zip(sgprs, consts))
3386 dgroup3 =
3387 LLVM::InsertElementOp::create(rewriter, loc, dgroup3, sgpr, constant);
3388
3389 return dgroup3;
3390 }
3391
3392 Value getDGroup3Gather(DescriptorOp op, OpAdaptor adaptor,
3393 ConversionPatternRewriter &rewriter, Location loc,
3394 ArrayRef<Value> consts) const {
3395 return getGatherIndices(op, adaptor, rewriter, loc, consts, false);
3396 }
3397
3398 LogicalResult
3399 matchAndRewrite(DescriptorOp op, OpAdaptor adaptor,
3400 ConversionPatternRewriter &rewriter) const override {
3401 if (chipset < kGfx1250)
3402 return op->emitOpError(
3403 "is only supported on gfx1250");
3404
3405 Location loc = op.getLoc();
3406
3407 SmallVector<Value> consts;
3408 for (int64_t i = 0; i < 8; ++i)
3409 consts.push_back(createI32Constant(rewriter, loc, i));
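// consts[0..7] hold the i32 constants 0 through 7; they are reused both as small field values and as vector lane indices.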
3410
3411 Value dgroup0 = this->getDGroup0(adaptor);
3412 Value dgroup1 = this->getDGroup1(op, adaptor, rewriter, loc, consts);
3413 Value dgroup2 = this->getDGroup2(op, adaptor, rewriter, loc, consts);
3414 Value dgroup3 = this->getDGroup3(op, adaptor, rewriter, loc, consts);
3415 SmallVector<Value> results = {dgroup0, dgroup1, dgroup2, dgroup3};
3416 rewriter.replaceOpWithMultiple(op, {results});
3417 return success();
3418 }
3419};
3420
3421template <typename SourceOp, typename TargetOp>
3422struct AMDGPUTensorLoadStoreOpLowering
3423 : public ConvertOpToLLVMPattern<SourceOp> {
3424 using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
3425 using Adaptor = typename ConvertOpToLLVMPattern<SourceOp>::OpAdaptor;
3426 AMDGPUTensorLoadStoreOpLowering(const LLVMTypeConverter &converter,
3427 Chipset chipset)
3428 : ConvertOpToLLVMPattern<SourceOp>(converter), chipset(chipset) {}
3429 Chipset chipset;
3430
3431 LogicalResult
3432 matchAndRewrite(SourceOp op, Adaptor adaptor,
3433 ConversionPatternRewriter &rewriter) const override {
3434 if (chipset < kGfx1250)
3435 return op->emitOpError("is only supported on gfx1250");
3436
3437 ValueRange desc = adaptor.getDesc();
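// The descriptor type was expanded by the type converter into four values (dgroup0..dgroup3); forward them to the ROCDL op.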
3438 rewriter.replaceOpWithNewOp<TargetOp>(op, desc[0], desc[1], desc[2],
3439 desc[3], /*cachePolicy=*/0,
3440 /*alias_scopes=*/nullptr,
3441 /*noalias_scopes=*/nullptr,
3442 /*tbaa=*/nullptr);
3443 return success();
3444 }
3445};
3446
3447struct ConvertAMDGPUToROCDLPass
3448 : public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
3449 using Base::Base;
3450
3451 void runOnOperation() override {
3452 MLIRContext *ctx = &getContext();
3453 FailureOr<Chipset> maybeChipset = Chipset::parse(chipset);
3454 if (failed(maybeChipset)) {
3455 emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
3456 return signalPassFailure();
3457 }
3458
3459 RewritePatternSet patterns(ctx);
3460 LLVMTypeConverter converter(ctx);
3461
3462 populateAMDGPUToROCDLConversionPatterns(converter, patterns, *maybeChipset);
3463 amdgpu::populateCommonGPUTypeAndAttributeConversions(converter);
3464 LLVMConversionTarget target(getContext());
3465 target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
3466 target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
3467 target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
3468 if (failed(applyPartialConversion(getOperation(), target,
3469 std::move(patterns))))
3470 signalPassFailure();
3471 }
3472};
3473} // namespace
3474
3475void mlir::amdgpu::populateCommonGPUTypeAndAttributeConversions(
3476 TypeConverter &typeConverter) {
3477 populateGpuMemorySpaceAttributeConversions(
3478 typeConverter, [](gpu::AddressSpace space) {
3479 switch (space) {
3480 case gpu::AddressSpace::Global:
3481 return ROCDL::ROCDLDialect::kGlobalMemoryAddressSpace;
3482 case gpu::AddressSpace::Workgroup:
3483 return ROCDL::ROCDLDialect::kSharedMemoryAddressSpace;
3484 case gpu::AddressSpace::Private:
3485 return ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace;
3486 }
3487 llvm_unreachable("unknown address space enum value");
3488 });
3489}
3490
3491void mlir::amdgpu::populateAMDGPUTypeAndAttributeConversions(
3492 TypeConverter &typeConverter) {
3493 typeConverter.addTypeAttributeConversion(
3494 [](BaseMemRefType type, amdgpu::AddressSpaceAttr as)
3495 -> TypeConverter::AttributeConversionResult {
3496 MLIRContext *ctx = as.getContext();
3497 Type i64 = IntegerType::get(ctx, 64);
3498 switch (as.getValue()) {
3499 case amdgpu::AddressSpace::FatRawBuffer:
3500 return IntegerAttr::get(i64, 7);
3501 case amdgpu::AddressSpace::BufferRsrc:
3502 return IntegerAttr::get(i64, 8);
3503 case amdgpu::AddressSpace::FatStructuredBuffer:
3504 return IntegerAttr::get(i64, 9);
3505 }
3506 return TypeConverter::AttributeConversionResult::abort();
3507 });
3508 typeConverter.addConversion([&](TDMBaseType type) -> Type {
3509 Type i32 = IntegerType::get(type.getContext(), 32);
3510 return typeConverter.convertType(VectorType::get(4, i32));
3511 });
3512 typeConverter.addConversion([&](TDMGatherBaseType type) -> Type {
3513 Type i32 = IntegerType::get(type.getContext(), 32);
3514 return typeConverter.convertType(VectorType::get(4, i32));
3515 });
3516 typeConverter.addConversion(
3517 [&](TDMDescriptorType type,
3518 SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
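// A TDM descriptor lowers to four values: dgroup0 (v4i32), dgroup1 (v8i32), and dgroup2/dgroup3 (v4i32 each).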
3519 Type i32 = IntegerType::get(type.getContext(), 32);
3520 Type v4i32 = typeConverter.convertType(VectorType::get(4, i32));
3521 Type v8i32 = typeConverter.convertType(VectorType::get(8, i32));
3522 llvm::append_values(result, v4i32, v8i32, v4i32, v4i32);
3523 return success();
3524 });
3525
3526 auto addUnrealizedCast = [](OpBuilder &builder, TypeRange types,
3527 ValueRange inputs,
3528 Location loc) -> SmallVector<Value> {
3529 // Only create unrealized_conversion_cast for TDMDescriptorType.
3530 // All other types which are not expected, should be
3531 // materialized by other target materialization functions.
3532 if (inputs.size() != 1)
3533 return {};
3534
3535 if (!isa<TDMDescriptorType>(inputs[0].getType()))
3536 return {};
3537
3538 auto cast = UnrealizedConversionCastOp::create(builder, loc, types, inputs);
3539 return cast.getResults();
3540 };
3541
3542 typeConverter.addTargetMaterialization(addUnrealizedCast);
3543}
3544
3545void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
3546 RewritePatternSet &patterns,
3547 Chipset chipset) {
3549 patterns
3550 .add<FatRawBufferCastLowering,
3551 RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
3552 RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
3553 RawBufferOpLowering<RawBufferAtomicFaddOp,
3554 ROCDL::RawPtrBufferAtomicFaddOp>,
3555 RawBufferOpLowering<RawBufferAtomicFmaxOp,
3556 ROCDL::RawPtrBufferAtomicFmaxOp>,
3557 RawBufferOpLowering<RawBufferAtomicSmaxOp,
3558 ROCDL::RawPtrBufferAtomicSmaxOp>,
3559 RawBufferOpLowering<RawBufferAtomicUminOp,
3560 ROCDL::RawPtrBufferAtomicUminOp>,
3561 RawBufferOpLowering<RawBufferAtomicCmpswapOp,
3562 ROCDL::RawPtrBufferAtomicCmpSwap>,
3563 AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
3564 SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
3565 SparseMFMAOpLowering, WMMAOpLowering, ScaledWMMAOpLowering,
3566 ExtPackedFp8OpLowering, ScaledExtPackedMatrixOpLowering,
3567 ScaledExtPackedOpLowering, PackedScaledTruncOpLowering,
3568 PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
3569 GatherToLDSOpLowering, TransposeLoadOpLowering,
3570 AMDGPUPermlaneLowering, AMDGPUMakeDmaBaseLowering<MakeDmaBaseOp>,
3571 AMDGPUMakeDmaBaseLowering<MakeGatherDmaBaseOp>,
3572 AMDGPULowerDescriptor<MakeDmaDescriptorOp>,
3573 AMDGPULowerDescriptor<MakeGatherDmaDescriptorOp>,
3574 AMDGPUTensorLoadStoreOpLowering<TensorLoadToLDSOp,
3575 ROCDL::TensorLoadToLDSOp>,
3576 AMDGPUTensorLoadStoreOpLowering<TensorStoreFromLDSOp,
3577 ROCDL::TensorStoreFromLDSOp>>(
3578 converter, chipset);
3579 patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
3580}