AMDGPUToROCDL.cpp
1//===- AMDGPUToROCDL.cpp - AMDGPU to ROCDL dialect conversion -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
20#include "mlir/IR/Attributes.h"
23#include "mlir/IR/Matchers.h"
25#include "mlir/Pass/Pass.h"
26
28
29#include "llvm/ADT/STLExtras.h"
30#include "llvm/ADT/TypeSwitch.h"
31#include "llvm/Support/Casting.h"
32#include "llvm/Support/ErrorHandling.h"
33#include <optional>
34
35namespace mlir {
36#define GEN_PASS_DEF_CONVERTAMDGPUTOROCDLPASS
37#include "mlir/Conversion/Passes.h.inc"
38} // namespace mlir
39
40using namespace mlir;
41using namespace mlir::amdgpu;
42
43// Define commonly used chipset versions for convenience.
44constexpr Chipset kGfx908 = Chipset(9, 0, 8);
45constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
46constexpr Chipset kGfx942 = Chipset(9, 4, 2);
47constexpr Chipset kGfx950 = Chipset(9, 5, 0);
48constexpr Chipset kGfx1250 = Chipset(12, 5, 0);
49
50/// Convert an unsigned number `val` to i32.
51static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
52 Location loc, Value val) {
53 IntegerType i32 = rewriter.getI32Type();
54 // Force check that `val` is of int type.
55 auto valTy = cast<IntegerType>(val.getType());
56 if (i32 == valTy)
57 return val;
58 return valTy.getWidth() > 32
59 ? Value(LLVM::TruncOp::create(rewriter, loc, i32, val))
60 : Value(LLVM::ZExtOp::create(rewriter, loc, i32, val));
61}
62
63static Value createI32Constant(ConversionPatternRewriter &rewriter,
64 Location loc, int32_t value) {
65 return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), value);
66}
67
68/// Convert an unsigned number `val` to i64.
69static Value convertUnsignedToI64(ConversionPatternRewriter &rewriter,
70 Location loc, Value val) {
71 IntegerType i64 = rewriter.getI64Type();
72 // Force check that `val` is of int type.
73 auto valTy = cast<IntegerType>(val.getType());
74 if (i64 == valTy)
75 return val;
76 return valTy.getWidth() > 64
77 ? Value(LLVM::TruncOp::create(rewriter, loc, i64, val))
78 : Value(LLVM::ZExtOp::create(rewriter, loc, i64, val));
79}
80
81static Value createI64Constant(ConversionPatternRewriter &rewriter,
82 Location loc, int64_t value) {
83 return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), value);
84}
85
86/// Returns the linear index used to access an element in the memref.
87static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
88 Location loc, MemRefDescriptor &memRefDescriptor,
89 ValueRange indices, ArrayRef<int64_t> strides) {
90 IntegerType i32 = rewriter.getI32Type();
91 Value index;
92 for (auto [i, increment, stride] : llvm::enumerate(indices, strides)) {
93 if (stride != 1) { // Skip if stride is 1.
94 Value strideValue =
95 ShapedType::isDynamic(stride)
96 ? convertUnsignedToI32(rewriter, loc,
97 memRefDescriptor.stride(rewriter, loc, i))
98 : LLVM::ConstantOp::create(rewriter, loc, i32, stride);
99 increment = LLVM::MulOp::create(rewriter, loc, increment, strideValue);
100 }
101 index = index ? LLVM::AddOp::create(rewriter, loc, index, increment)
102 : increment;
103 }
104 return index ? index : createI32Constant(rewriter, loc, 0);
105}
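// For instance, assuming a memref whose compile-time strides are [4, 1] and
// two dynamic indices %i and %j, the loop above emits roughly
//   %s = llvm.mul %i, c4 : i32   (stride != 1, so multiply)
//   %t = llvm.add %s, %j : i32   (the unit stride is skipped)
// and returns %t; a zero constant is returned only when there are no indices
// to accumulate (rank-0 memrefs).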
106
107/// Compute the contents of the `num_records` field for a given memref
108/// descriptor - that is, the byte offset one element past the
109/// greatest possible valid index into the memref.
110static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc,
111 MemRefType memrefType,
112 MemRefDescriptor &memrefDescriptor,
113 ArrayRef<int64_t> strides, int64_t elementByteWidth,
114 amdgpu::Chipset chipset, bool boundsCheck) {
115 if (chipset >= kGfx1250 && !boundsCheck) {
116 constexpr int64_t first45bits = (1ll << 45) - 1;
117 return createI64Constant(rewriter, loc, first45bits);
118 }
119 if (memrefType.hasStaticShape() &&
120 !llvm::any_of(strides, ShapedType::isDynamic)) {
121 int64_t size = memrefType.getRank() == 0 ? 1 : 0;
122 ArrayRef<int64_t> shape = memrefType.getShape();
123 for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
124 size = std::max(shape[i] * strides[i], size);
125 size = size * elementByteWidth;
126 return createI64Constant(rewriter, loc, size);
127 }
128 Value maxIndex;
129 for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
130 Value size = memrefDescriptor.size(rewriter, loc, i);
131 Value stride = memrefDescriptor.stride(rewriter, loc, i);
132 Value maxThisDim = LLVM::MulOp::create(rewriter, loc, size, stride);
133 maxIndex = maxIndex
134 ? LLVM::UMaxOp::create(rewriter, loc, maxIndex, maxThisDim)
135 : maxThisDim;
136 }
137 Value maxIndexI64 = convertUnsignedToI64(rewriter, loc, maxIndex);
138 Value byteWidthConst = createI64Constant(rewriter, loc, elementByteWidth);
139 return LLVM::MulOp::create(rewriter, loc, maxIndexI64, byteWidthConst);
140}
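// As a worked example, assuming a statically shaped memref<16x4xf32> with
// strides [4, 1] and elementByteWidth = 4, the static branch above computes
// size = max(16 * 4, 4 * 1) * 4 = 256 bytes, i.e. the byte offset one element
// past the last addressable element. When bounds checking is disabled on
// gfx1250+, the constant (1 << 45) - 1 is returned instead so the hardware
// range check never triggers.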
141
142static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc,
143 Value basePointer, Value numRecords,
144 bool boundsCheck, amdgpu::Chipset chipset,
145 Value cacheSwizzleStride = nullptr,
146 unsigned addressSpace = 8) {
147 // The stride value is generally 0. However, on MI-300 and onward, you can
148 // enable a cache swizzling mode by setting bit 14 of the stride field
149 // and setting that stride to a cache stride.
150 Type i16 = rewriter.getI16Type();
151 Value stride;
152 if (chipset.majorVersion == 9 && chipset >= kGfx942 && cacheSwizzleStride) {
153 Value cacheStrideZext =
154 LLVM::ZExtOp::create(rewriter, loc, i16, cacheSwizzleStride);
155 Value swizzleBit = LLVM::ConstantOp::create(
156 rewriter, loc, i16, rewriter.getI16IntegerAttr(1 << 14));
157 stride = LLVM::OrOp::create(rewriter, loc, cacheStrideZext, swizzleBit,
158 /*isDisjoint=*/true);
159 } else {
160 stride = LLVM::ConstantOp::create(rewriter, loc, i16,
161 rewriter.getI16IntegerAttr(0));
162 }
163
164 uint32_t flags = 0;
165 if (chipset >= kGfx1250) {
166 // Flag word:
167 // bit 0: swizzle
168 // bit 1: out-of-bounds mode. 0 means OOB when
169 // (total_offset + payload) > numRecords; 1 means OOB when
170 // ((total_offset + payload) > numRecords) || ((offset + payload) > stride).
171 // Only applies when swizzle_enable = 0; keep at zero.
172 // Whether the OOB check happens at all depends on numRecords.
173 // bits 2-3: Type (must be 0)
174 } else {
175 // Get the number of elements.
176 // Flag word:
177 // bits 0-11: dst sel, ignored by these intrinsics
178 // bits 12-14: num format (ignored, must be nonzero, 7=float)
179 // bits 15-18: data format (ignored, must be nonzero, 4=32bit)
180 // bit 19: In nested heap (0 here)
181 // bit 20: Behavior on unmap (0 means "return 0 / ignore")
182 // bits 21-22: Index stride for swizzles (N/A)
183 // bit 23: Add thread ID (0)
184 // bit 24: Reserved to 1 (RDNA) or 0 (CDNA)
185 // bits 25-26: Reserved (0)
186 // bit 27: Buffer is non-volatile (CDNA only)
187 // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 =
188 // none, 3 = either swizzles or testing against offset field) RDNA only
189 // bits 30-31: Type (must be 0)
190 flags |= (7 << 12) | (4 << 15);
191 if (chipset.majorVersion >= 10) {
192 flags |= (1 << 24);
193 uint32_t oob = boundsCheck ? 3 : 2;
194 flags |= (oob << 28);
195 }
196 }
197 Value flagsConst = createI32Constant(rewriter, loc, flags);
198 Type rsrcType =
199 LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
200 Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
201 loc, rsrcType, basePointer, stride, numRecords, flagsConst);
202 return resource;
203}
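// Illustrative flag words produced by the pre-gfx1250 branch above, following
// the bitfields documented in the comments:
//   CDNA (gfx9xx):                  flags = (7 << 12) | (4 << 15) = 0x27000
//   RDNA (gfx10+), boundsCheck on:  flags = 0x27000 | (1 << 24) | (3u << 28)
//                                         = 0x31027000
//   RDNA (gfx10+), boundsCheck off: flags = 0x27000 | (1 << 24) | (2u << 28)
//                                         = 0x21027000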
204
205namespace {
206struct FatRawBufferCastLowering
207 : public ConvertOpToLLVMPattern<FatRawBufferCastOp> {
208 FatRawBufferCastLowering(const LLVMTypeConverter &converter, Chipset chipset)
209 : ConvertOpToLLVMPattern<FatRawBufferCastOp>(converter),
210 chipset(chipset) {}
211
212 Chipset chipset;
213
214 LogicalResult
215 matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor,
216 ConversionPatternRewriter &rewriter) const override {
217 Location loc = op.getLoc();
218 Value memRef = adaptor.getSource();
219 Value unconvertedMemref = op.getSource();
220 MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());
221 MemRefDescriptor descriptor(memRef);
222
223 DataLayout dataLayout = DataLayout::closest(op);
224 int64_t elementByteWidth =
225 dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;
226
227 int64_t unusedOffset = 0;
228 SmallVector<int64_t, 5> strideVals;
229 if (failed(memrefType.getStridesAndOffset(strideVals, unusedOffset)))
230 return op.emitOpError("Can't lower non-stride-offset memrefs");
231
232 Value numRecords = adaptor.getValidBytes();
233 if (!numRecords)
234 numRecords =
235 getNumRecords(rewriter, loc, memrefType, descriptor, strideVals,
236 elementByteWidth, chipset, adaptor.getBoundsCheck());
237
238 Value basePointer =
239 adaptor.getResetOffset()
240 ? descriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
241 memrefType)
242 : descriptor.alignedPtr(rewriter, loc);
243
244 Value offset = adaptor.getResetOffset()
245 ? LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
246 rewriter.getIndexAttr(0))
247 : descriptor.offset(rewriter, loc);
248
249 bool hasSizes = memrefType.getRank() > 0;
250 // No need to unpack() and pack() all the individual sizes and strides,
251 // so we'll just extract the arrays.
252 Value sizes = hasSizes
253 ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
254 kSizePosInMemRefDescriptor)
255 : Value{};
256 Value strides =
257 hasSizes ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
258 kStridePosInMemRefDescriptor)
259 : Value{};
260
261 Value fatPtr = makeBufferRsrc(
262 rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(),
263 chipset, adaptor.getCacheSwizzleStride(), /*addressSpace=*/7);
264
265 Value result = MemRefDescriptor::poison(
266 rewriter, loc,
267 getTypeConverter()->convertType(op.getResult().getType()));
268 SmallVector<int64_t> pos{kAllocatedPtrPosInMemRefDescriptor};
269 result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr, pos);
270 result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr,
271 kAlignedPtrPosInMemRefDescriptor);
272 result = LLVM::InsertValueOp::create(rewriter, loc, result, offset,
273 kOffsetPosInMemRefDescriptor);
274 if (hasSizes) {
275 result = LLVM::InsertValueOp::create(rewriter, loc, result, sizes,
276 kSizePosInMemRefDescriptor);
277 result = LLVM::InsertValueOp::create(rewriter, loc, result, strides,
278 kStridePosInMemRefDescriptor);
279 }
280 rewriter.replaceOp(op, result);
281 return success();
282 }
283};
284
285/// Define lowering patterns for raw buffer ops
286template <typename GpuOp, typename Intrinsic>
287struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
288 RawBufferOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
289 : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}
290
291 Chipset chipset;
292 static constexpr uint32_t maxVectorOpWidth = 128;
293
294 LogicalResult
295 matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor,
296 ConversionPatternRewriter &rewriter) const override {
297 Location loc = gpuOp.getLoc();
298 Value memref = adaptor.getMemref();
299 Value unconvertedMemref = gpuOp.getMemref();
300 MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());
301
302 if (chipset.majorVersion < 9)
303 return gpuOp.emitOpError("raw buffer ops require GCN or higher");
304
305 Value storeData = adaptor.getODSOperands(0)[0];
306 if (storeData == memref) // no write component to this op
307 storeData = Value();
308 Type wantedDataType;
309 if (storeData)
310 wantedDataType = storeData.getType();
311 else
312 wantedDataType = gpuOp.getODSResults(0)[0].getType();
313
314 Value atomicCmpData = Value();
315 // Operand index 1 of a load is the indices; trying to read them can crash.
316 if (storeData) {
317 Value maybeCmpData = adaptor.getODSOperands(1)[0];
318 if (maybeCmpData != memref)
319 atomicCmpData = maybeCmpData;
320 }
321
322 Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
323
324 Type i32 = rewriter.getI32Type();
325
326 // Get the type size in bytes.
327 DataLayout dataLayout = DataLayout::closest(gpuOp);
328 int64_t elementByteWidth =
329 dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;
330 Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);
331
332 // If we want to load a vector<NxT> with total size <= 32
333 // bits, use a scalar load and bitcast it. Similarly, if bitsize(T) < 32
334 // and the total load size is >= 32, use a vector load of
335 // (N * bitsize(T)) / 32 x i32 and bitcast. Also, the CAS intrinsic requires
336 // integer operands, so bitcast any floats to integers.
337 Type llvmBufferValType = llvmWantedDataType;
338 if (atomicCmpData) {
339 if (auto floatType = dyn_cast<FloatType>(wantedDataType))
340 llvmBufferValType = this->getTypeConverter()->convertType(
341 rewriter.getIntegerType(floatType.getWidth()));
342 }
343 if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
344 uint32_t vecLen = dataVector.getNumElements();
345 uint32_t elemBits =
346 dataLayout.getTypeSizeInBits(dataVector.getElementType());
347 uint32_t totalBits = elemBits * vecLen;
348 bool usePackedFp16 =
349 isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
350 if (totalBits > maxVectorOpWidth)
351 return gpuOp.emitOpError(
352 "Total width of loads or stores must be no more than " +
353 Twine(maxVectorOpWidth) + " bits, but we call for " +
354 Twine(totalBits) +
355 " bits. This should've been caught in validation");
356 if (!usePackedFp16 && elemBits < 32) {
357 if (totalBits > 32) {
358 if (totalBits % 32 != 0)
359 return gpuOp.emitOpError("Load or store of more than 32-bits that "
360 "doesn't fit into words. Can't happen\n");
361 llvmBufferValType = this->typeConverter->convertType(
362 VectorType::get(totalBits / 32, i32));
363 } else {
364 llvmBufferValType = this->typeConverter->convertType(
365 rewriter.getIntegerType(totalBits));
366 }
367 }
368 }
369 if (auto vecType = dyn_cast<VectorType>(llvmBufferValType)) {
370 // Buffer intrinsics don't support 1-element vectors; cast them to
371 // scalars.
372 if (vecType.getNumElements() == 1)
373 llvmBufferValType = vecType.getElementType();
374 }
375
376 SmallVector<Value, 6> args;
377 if (storeData) {
378 if (llvmBufferValType != llvmWantedDataType) {
379 Value castForStore = LLVM::BitcastOp::create(
380 rewriter, loc, llvmBufferValType, storeData);
381 args.push_back(castForStore);
382 } else {
383 args.push_back(storeData);
384 }
385 }
386
387 if (atomicCmpData) {
388 if (llvmBufferValType != llvmWantedDataType) {
389 Value castForCmp = LLVM::BitcastOp::create(
390 rewriter, loc, llvmBufferValType, atomicCmpData);
391 args.push_back(castForCmp);
392 } else {
393 args.push_back(atomicCmpData);
394 }
395 }
396
397 // Construct buffer descriptor from memref, attributes
398 int64_t offset = 0;
399 SmallVector<int64_t, 5> strides;
400 if (failed(memrefType.getStridesAndOffset(strides, offset)))
401 return gpuOp.emitOpError("Can't lower non-stride-offset memrefs");
402
403 MemRefDescriptor memrefDescriptor(memref);
404
405 Value ptr = memrefDescriptor.bufferPtr(
406 rewriter, loc, *this->getTypeConverter(), memrefType);
407 Value numRecords =
408 getNumRecords(rewriter, loc, memrefType, memrefDescriptor, strides,
409 elementByteWidth, chipset, adaptor.getBoundsCheck());
410 Value resource = makeBufferRsrc(rewriter, loc, ptr, numRecords,
411 adaptor.getBoundsCheck(), chipset);
412 args.push_back(resource);
413
414 // Indexing (voffset)
415 Value voffset = getLinearIndexI32(rewriter, loc, memrefDescriptor,
416 adaptor.getIndices(), strides);
417 if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
418 indexOffset && *indexOffset > 0) {
419 Value extraOffsetConst = createI32Constant(rewriter, loc, *indexOffset);
420 voffset = voffset ? LLVM::AddOp::create(rewriter, loc, voffset,
421 extraOffsetConst)
422 : extraOffsetConst;
423 }
424 voffset = LLVM::MulOp::create(rewriter, loc, voffset, byteWidthConst);
425 args.push_back(voffset);
426
427 // SGPR offset.
428 Value sgprOffset = adaptor.getSgprOffset();
429 if (!sgprOffset)
430 sgprOffset = createI32Constant(rewriter, loc, 0);
431 sgprOffset = LLVM::MulOp::create(rewriter, loc, sgprOffset, byteWidthConst);
432 args.push_back(sgprOffset);
433
434 // bit 0: GLC = 0 (atomics drop value, less coherency)
435 // bits 1-2: SLC, DLC = 0 (similarly)
436 // bit 3: swizzled (0 for raw)
437 args.push_back(createI32Constant(rewriter, loc, 0));
438
439 llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
440 llvmBufferValType);
441 Operation *lowered = Intrinsic::create(rewriter, loc, resultTypes, args,
442 ArrayRef<NamedAttribute>());
443 if (lowered->getNumResults() == 1) {
444 Value replacement = lowered->getResult(0);
445 if (llvmBufferValType != llvmWantedDataType) {
446 replacement = LLVM::BitcastOp::create(rewriter, loc, llvmWantedDataType,
447 replacement);
448 }
449 rewriter.replaceOp(gpuOp, replacement);
450 } else {
451 rewriter.eraseOp(gpuOp);
452 }
453 return success();
454 }
455};
456
457// TODO: the AMDGPU backend already has all this bitpacking logic; we should move
458// it to some common place.
459/// Vmcnt, Expcnt and Lgkmcnt are decoded as follows:
460/// Vmcnt = Waitcnt[3:0] (pre-gfx9)
461/// Vmcnt = Waitcnt[15:14,3:0] (gfx9,10)
462/// Vmcnt = Waitcnt[15:10] (gfx11)
463/// Expcnt = Waitcnt[6:4] (pre-gfx11)
464/// Expcnt = Waitcnt[2:0] (gfx11)
465/// Lgkmcnt = Waitcnt[11:8] (pre-gfx10)
466/// Lgkmcnt = Waitcnt[13:8] (gfx10)
467/// Lgkmcnt = Waitcnt[9:4] (gfx11)
468static FailureOr<unsigned> encodeWaitcnt(Chipset chipset, unsigned vmcnt,
469 unsigned expcnt, unsigned lgkmcnt) {
470 if (chipset.majorVersion < 9) {
471 vmcnt = std::min(15u, vmcnt);
472 expcnt = std::min(7u, expcnt);
473 lgkmcnt = std::min(15u, lgkmcnt);
474 return vmcnt | (expcnt << 4) | (lgkmcnt << 8);
475 }
476 if (chipset.majorVersion == 9) {
477 vmcnt = std::min(63u, vmcnt);
478 expcnt = std::min(7u, expcnt);
479 lgkmcnt = std::min(15u, lgkmcnt);
480 unsigned lowBits = vmcnt & 0xF;
481 unsigned highBits = (vmcnt >> 4) << 14;
482 unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
483 return lowBits | highBits | otherCnts;
484 }
485 if (chipset.majorVersion == 10) {
486 vmcnt = std::min(63u, vmcnt);
487 expcnt = std::min(7u, expcnt);
488 lgkmcnt = std::min(63u, lgkmcnt);
489 unsigned lowBits = vmcnt & 0xF;
490 unsigned highBits = (vmcnt >> 4) << 14;
491 unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
492 return lowBits | highBits | otherCnts;
493 }
494 if (chipset.majorVersion == 11) {
495 vmcnt = std::min(63u, vmcnt);
496 expcnt = std::min(7u, expcnt);
497 lgkmcnt = std::min(63u, lgkmcnt);
498 return (vmcnt << 10) | expcnt | (lgkmcnt << 4);
499 }
500 return failure();
501}
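// Worked example of the gfx9 encoding above, assuming vmcnt = 18, expcnt = 7,
// and lgkmcnt = 15:
//   lowBits   = 18 & 0xF             = 0x2
//   highBits  = (18 >> 4) << 14      = 0x4000
//   otherCnts = (7 << 4) | (15 << 8) = 0xF70
//   result    = 0x2 | 0x4000 | 0xF70 = 0x4F72
// which places vmcnt in the split Waitcnt[15:14, 3:0] field described in the
// decoding table above.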
502
503struct MemoryCounterWaitOpLowering
504 : public ConvertOpToLLVMPattern<MemoryCounterWaitOp> {
505 MemoryCounterWaitOpLowering(const LLVMTypeConverter &converter,
506 Chipset chipset)
507 : ConvertOpToLLVMPattern<MemoryCounterWaitOp>(converter),
508 chipset(chipset) {}
509
510 Chipset chipset;
511
512 LogicalResult
513 matchAndRewrite(MemoryCounterWaitOp op, OpAdaptor adaptor,
514 ConversionPatternRewriter &rewriter) const override {
515 if (chipset.majorVersion >= 12) {
516 Location loc = op.getLoc();
517 if (std::optional<int> ds = adaptor.getDs())
518 ROCDL::WaitDscntOp::create(rewriter, loc, *ds);
519
520 if (std::optional<int> load = adaptor.getLoad())
521 ROCDL::WaitLoadcntOp::create(rewriter, loc, *load);
522
523 if (std::optional<int> store = adaptor.getStore())
524 ROCDL::WaitStorecntOp::create(rewriter, loc, *store);
525
526 if (std::optional<int> exp = adaptor.getExp())
527 ROCDL::WaitExpcntOp::create(rewriter, loc, *exp);
528
529 if (std::optional<int> tensor = adaptor.getTensor())
530 ROCDL::WaitTensorcntOp::create(rewriter, loc, *tensor);
532 rewriter.eraseOp(op);
533 return success();
534 }
535
536 if (adaptor.getTensor())
537 return op.emitOpError("unsupported chipset");
538
539 auto getVal = [](Attribute attr) -> unsigned {
540 if (attr)
541 return cast<IntegerAttr>(attr).getInt();
542
543 // This value will be clamped to the maximum value for the chipset.
544 return 1024;
545 };
546 unsigned ds = getVal(adaptor.getDsAttr());
547 unsigned exp = getVal(adaptor.getExpAttr());
548
549 unsigned vmcnt = 1024;
550 Attribute load = adaptor.getLoadAttr();
551 Attribute store = adaptor.getStoreAttr();
552 if (load && store) {
553 vmcnt = getVal(load) + getVal(store);
554 } else if (load) {
555 vmcnt = getVal(load);
556 } else if (store) {
557 vmcnt = getVal(store);
558 }
559
560 FailureOr<unsigned> waitcnt = encodeWaitcnt(chipset, vmcnt, exp, ds);
561 if (failed(waitcnt))
562 return op.emitOpError("unsupported chipset");
563
564 rewriter.replaceOpWithNewOp<ROCDL::SWaitcntOp>(op, *waitcnt);
565 return success();
566 }
567};
568
569struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
570 LDSBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
571 : ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {}
572
573 Chipset chipset;
574
575 LogicalResult
576 matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
577 ConversionPatternRewriter &rewriter) const override {
578 Location loc = op.getLoc();
579 // This ensures that waits on global memory aren't introduced on
580 // chips that don't have the BackOffBarrier feature enabled in LLVM.
581 bool requiresInlineAsm = chipset < kGfx90a;
582
583 Attribute mmra =
584 rewriter.getAttr<LLVM::MMRATagAttr>("amdgpu-synchronize-as", "local");
585 // Note: while there *is* a workgroup-one-as scope, this, when combined with
586 // the MMRA, will lead to the fence having no effect. This is because the
587 // codepaths for an atomic load or store will observe that a
588 // one-address-space atomic to LDS requires no synchronization because
589 // operations on LDS are totally ordered with respect to each other, and so
590 // will not emit the correct waitcnt operations that these fences are
591 // intended to produce. Therefore, we use a broader type of fence and rely
592 // on the MMRA to relax it to the semantics we want.
593 StringRef scope = "workgroup";
594
595 auto relFence = LLVM::FenceOp::create(rewriter, loc,
596 LLVM::AtomicOrdering::release, scope);
597 relFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra);
598 if (requiresInlineAsm) {
599 auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
600 LLVM::AsmDialect::AD_ATT);
601 const char *asmStr = ";;;WARNING: BREAKS DEBUG WATCHES\ns_barrier";
602 const char *constraints = "";
603 LLVM::InlineAsmOp::create(
604 rewriter, loc,
605 /*resultTypes=*/TypeRange(), /*operands=*/ValueRange(),
606 /*asm_string=*/asmStr, constraints, /*has_side_effects=*/true,
607 /*is_align_stack=*/false, LLVM::TailCallKind::None,
608 /*asm_dialect=*/asmDialectAttr,
609 /*operand_attrs=*/ArrayAttr());
610 } else if (chipset.majorVersion < 12) {
611 ROCDL::SBarrierOp::create(rewriter, loc);
612 } else {
613 ROCDL::BarrierSignalOp::create(rewriter, loc, -1);
614 ROCDL::BarrierWaitOp::create(rewriter, loc, -1);
615 }
616
617 auto acqFence = LLVM::FenceOp::create(rewriter, loc,
618 LLVM::AtomicOrdering::acquire, scope);
619 acqFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra);
620 rewriter.replaceOp(op, acqFence);
621 return success();
622 }
623};
624
625struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
626 SchedBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
627 : ConvertOpToLLVMPattern<SchedBarrierOp>(converter), chipset(chipset) {}
628
629 Chipset chipset;
630
631 LogicalResult
632 matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor,
633 ConversionPatternRewriter &rewriter) const override {
634 rewriter.replaceOpWithNewOp<ROCDL::SchedBarrier>(op,
635 (uint32_t)op.getOpts());
636 return success();
637 }
638};
639
640} // namespace
641
642/// Pack small float vector operands (fp4/fp6/fp8/bf16) into the format
643/// expected by scaled matrix multiply intrinsics (MFMA/WMMA).
644///
645/// Specifically:
646/// 1. If the element type is bfloat16, bitcast it to i16 unless the ROCDL
647/// intrinsic allows bf16. Newer MFMAs support bf16 operands; see
648/// IntrinsicsAMDGPU.td for reference.
649/// 2. If instead we have a more than 64-bit quantity, use a <N / 4 x i32>
650/// instead, which is what the f8f6f4 intrinsics use.
651/// 3. If `input` is a vector of N <= 8 bytes, bitcast it to a (N * 8)-bit
652/// integer.
653///
654/// Note that the type of `input` has already been LLVM type converted:
655/// therefore 8-bit and smaller floats are represented as their corresponding
656/// `iN` integers.
657static Value packSmallFloatVectorOperand(ConversionPatternRewriter &rewriter,
658 Location loc, Value input,
659 bool allowBf16 = true) {
660 Type inputType = input.getType();
661 if (auto vectorType = dyn_cast<VectorType>(inputType)) {
662 if (vectorType.getElementType().isBF16() && !allowBf16)
663 return LLVM::BitcastOp::create(
664 rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
665 if (vectorType.getElementType().isInteger(8) &&
666 vectorType.getNumElements() <= 8)
667 return LLVM::BitcastOp::create(
668 rewriter, loc,
669 rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
670 if (isa<IntegerType>(vectorType.getElementType()) &&
671 vectorType.getElementTypeBitWidth() <= 8) {
672 int64_t numWords = llvm::divideCeil(
673 vectorType.getNumElements() * vectorType.getElementTypeBitWidth(),
674 32);
675 return LLVM::BitcastOp::create(
676 rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()),
677 input);
678 }
679 }
680 return input;
681}
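// For example, assuming an operand that was vector<8xf8E4M3FN> in MLIR: it
// reaches this helper as vector<8xi8> and is bitcast to a single i64
// (8 * 8 bits). A longer operand such as vector<32xi8> (packed fp8 data wider
// than 64 bits) takes the last branch instead and becomes vector<8xi32>,
// i.e. ceil(32 * 8 / 32) words.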
682
683/// Converts sparse MFMA (smfmac) operands to the expected ROCDL types.
684static Value convertSparseMFMAVectorOperand(ConversionPatternRewriter &rewriter,
685 Location loc, Value input,
686 bool allowBf16 = true) {
687 Type inputType = input.getType();
688 auto vectorType = cast<VectorType>(inputType);
689 // bf16 -> i16 when not allowed (pre-gfx950).
690 if (vectorType.getElementType().isBF16() && !allowBf16)
691 return LLVM::BitcastOp::create(
692 rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
693 // i8/fp8 vectors -> vector<Nxi32>.
694 if (isa<IntegerType>(vectorType.getElementType()) &&
695 vectorType.getElementTypeBitWidth() <= 8) {
696 int64_t numWords = llvm::divideCeil(
697 vectorType.getNumElements() * vectorType.getElementTypeBitWidth(), 32);
698 return LLVM::BitcastOp::create(
699 rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()), input);
700 }
701 return input;
702}
703
704/// Converts the scaled MFMA/WMMA operands, `scalesA` and `scalesB`, from MLIR
705/// AMDGPU dialect convention to ROCDL and LLVM AMDGPU intrinsics convention.
706///
707/// Specifically:
708/// 1. If `input` is an i8 value, zero-extend it to i32.
709/// 2. If `input` is a vector of 4 or 8 i8 elements, bitcast it to i32 or i64.
710///
711/// Note that the type of `input` has already been LLVM type converted:
712/// therefore 8-bit and smaller floats are represented as their corresponding
713/// `iN` integers.
714static Value castScaleOperand(ConversionPatternRewriter &rewriter, Location loc,
715 Value input) {
716 return TypeSwitch<Type, Value>(input.getType())
717 .Case([&](IntegerType) {
718 // Handle scalar i8: zero extend to i32.
719 return LLVM::ZExtOp::create(rewriter, loc, rewriter.getI32Type(),
720 input);
721 })
722 .Case([&](VectorType vectorType) {
723 // Handle vector<4xi8> -> i32 or vector<8xi8> -> i64.
724 int64_t numElements = vectorType.getNumElements();
725 assert((numElements == 4 || numElements == 8) &&
726 "scale operand must be a vector of length 4 or 8");
727 IntegerType outputType =
728 (numElements == 4) ? rewriter.getI32Type() : rewriter.getI64Type();
729 return LLVM::BitcastOp::create(rewriter, loc, outputType, input);
730 })
731 .DefaultUnreachable("unexpected input type for scale operand");
732}
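// Illustration of the conversions above: a scalar i8 scale is widened with
//   %s32 = llvm.zext %s : i8 to i32
// while packed scales arrive as byte vectors and are reinterpreted in place:
//   vector<4xi8> -> bitcast to i32,  vector<8xi8> -> bitcast to i64,
// matching the operand shapes the scaled MFMA/WMMA intrinsics expect.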
733
734/// Maps f8 scale element types to WMMA scale format codes.
735static std::optional<uint32_t> getWmmaScaleFormat(Type elemType) {
736 return TypeSwitch<Type, std::optional<uint32_t>>(elemType)
737 .Case([](Float8E8M0FNUType) { return 0; })
738 .Case([](Float8E4M3FNType) { return 2; })
739 .Default(std::nullopt);
740}
741
742/// Determines the ROCDL intrinsic name for scaled WMMA based on dimensions
743/// and scale block size (16 or 32).
744static std::optional<StringRef>
746 if (m == 16 && n == 16 && k == 128)
747 return isScale16
748 ? ROCDL::wmma_scale16_f32_16x16x128_f8f6f4::getOperationName()
749 : ROCDL::wmma_scale_f32_16x16x128_f8f6f4::getOperationName();
750
751 if (m == 32 && n == 16 && k == 128)
752 return isScale16 ? ROCDL::wmma_scale16_f32_32x16x128_f4::getOperationName()
753 : ROCDL::wmma_scale_f32_32x16x128_f4::getOperationName();
754
755 return std::nullopt;
756}
757
758/// Push an input operand. If it is a float type, nothing to do. If it is
759/// an integer type, then we also need to push its signedness (1 for signed, 0
760/// for unsigned) and we need to pack the input 16xi8 vector into a 4xi32
761/// vector (or the 8xi8 vector into a 2xi32 one for gfx12+).
762/// We also need to convert bfloat inputs to i16 to account for the bfloat
763/// intrinsics having been defined before the AMD backend supported bfloat. We
764/// similarly need to pack 8-bit float types into integers as if they were i8
765/// (which they are for the backend's purposes).
766static void wmmaPushInputOperand(
767 ConversionPatternRewriter &rewriter, Location loc,
768 const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput,
769 Value mlirInput, SmallVectorImpl<Value> &operands,
770 SmallVectorImpl<NamedAttribute> &attrs, StringRef attrName) {
771 Type inputType = llvmInput.getType();
772 auto vectorType = dyn_cast<VectorType>(inputType);
773 if (!vectorType) {
774 operands.push_back(llvmInput);
775 return;
776 }
777 Type elemType = vectorType.getElementType();
778 if (elemType.getIntOrFloatBitWidth() > 8) {
779 operands.push_back(llvmInput);
780 return;
781 }
782
783 // We need to check the type of the input before conversion to properly test
784 // for int8. This is because, in LLVM, fp8 type is converted to int8, so the
785 // fp8/int8 information is lost during the conversion process.
786 auto mlirInputType = cast<VectorType>(mlirInput.getType());
787 bool isInputInteger = mlirInputType.getElementType().isInteger();
788 if (isInputInteger) {
789 // if element type is 8-bit signed or unsigned, ignore the isUnsigned flag
790 bool localIsUnsigned = isUnsigned;
791 if (elemType.isUnsignedInteger()) {
792 localIsUnsigned = true;
793 } else if (elemType.isSignedInteger()) {
794 localIsUnsigned = false;
795 }
796 attrs.push_back(
797 NamedAttribute(attrName, rewriter.getBoolAttr(!localIsUnsigned)));
798 }
799
800 int64_t numBits =
801 vectorType.getNumElements() * elemType.getIntOrFloatBitWidth();
802 Type i32 = rewriter.getI32Type();
803 Type intrinsicInType = numBits <= 32
804 ? (Type)rewriter.getIntegerType(numBits)
805 : (Type)VectorType::get(numBits / 32, i32);
806 auto llvmIntrinsicInType = typeConverter->convertType(intrinsicInType);
807 Value castInput = rewriter.createOrFold<LLVM::BitcastOp>(
808 loc, llvmIntrinsicInType, llvmInput);
809 // The wave64-mode 16x16x16 intrinsics that take 4-bit integers only need
810 // (256 / 64) * 4 = 16 bits of input (on gfx12+) but take i32 arguments.
811 // Add in the zeros here.
812 if (numBits < 32)
813 castInput = LLVM::ZExtOp::create(rewriter, loc, i32, castInput);
814 operands.push_back(castInput);
815}
816
817/// Push the output operand. For many cases this is only pushing the output in
818/// the operand list. But when we have f16 -> f16 or bf16 -> bf16 intrinsics,
819/// since the same number of VGPRs is used, we need to decide whether to store
820/// the result in the upper 16 bits of the VGPRs or in the lower part. To store
821/// the result in the lower 16 bits, set subwordOffset to 1; otherwise the result
822/// will be stored in the upper part. The subwordOffset must not be set for gfx12,
823/// as the instructions have been changed to return fewer registers instead.
824static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
825 Location loc,
826 const TypeConverter *typeConverter,
827 Value output, int32_t subwordOffset,
828 bool clamp, SmallVectorImpl<Value> &operands,
829 SmallVectorImpl<NamedAttribute> &attrs) {
830 Type inputType = output.getType();
831 auto vectorType = dyn_cast<VectorType>(inputType);
832 Type elemType = vectorType.getElementType();
833 operands.push_back(output);
834 if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) {
835 attrs.push_back(
836 NamedAttribute("opsel", rewriter.getBoolAttr(subwordOffset)));
837 } else if (elemType.isInteger(32)) {
838 attrs.push_back(NamedAttribute("clamp", rewriter.getBoolAttr(clamp)));
839 }
840}
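// For example, for an f16 -> f16 WMMA the accumulator is pushed as-is and an
// "opsel" boolean attribute (from subwordOffset) selects which half of each
// VGPR receives the result, while i32 accumulators get a "clamp" attribute
// instead; f32 accumulators need no extra attribute.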
841
842/// Return true if `type` is the E5M2 variant of an 8-bit float that is
843/// supported by the `_bf8` instructions on the given `chipset`.
844static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type) {
845 return (chipset == kGfx942 && isa<Float8E5M2FNUZType>(type)) ||
846 (hasOcpFp8(chipset) && isa<Float8E5M2Type>(type));
847}
848
849/// Return true if `type` is the E4M3FN variant of an 8-bit float that is
850/// supported by the `_fp8` instructions on the given `chipset`.
851static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type) {
852 return (chipset == kGfx942 && isa<Float8E4M3FNUZType>(type)) ||
853 (hasOcpFp8(chipset) && isa<Float8E4M3FNType>(type));
854}
855
856/// Return the `rocdl` intrinsic corresponding to an MFMA operation `mfma`
857/// if one exists. This includes checking to ensure the intrinsic is supported
858/// on the architecture you are compiling for.
859static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
860 Chipset chipset) {
861 uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
862 b = mfma.getBlocks();
863 Type sourceElem = getElementTypeOrSelf(mfma.getSourceA().getType());
864 Type destElem = getElementTypeOrSelf(mfma.getDestC().getType());
865
866 if (sourceElem.isF32() && destElem.isF32()) {
867 if (mfma.getReducePrecision() && chipset >= kGfx942) {
868 if (m == 32 && n == 32 && k == 4 && b == 1)
869 return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
870 if (m == 16 && n == 16 && k == 8 && b == 1)
871 return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
872 }
873 if (m == 32 && n == 32 && k == 1 && b == 2)
874 return ROCDL::mfma_f32_32x32x1f32::getOperationName();
875 if (m == 16 && n == 16 && k == 1 && b == 4)
876 return ROCDL::mfma_f32_16x16x1f32::getOperationName();
877 if (m == 4 && n == 4 && k == 1 && b == 16)
878 return ROCDL::mfma_f32_4x4x1f32::getOperationName();
879 if (m == 32 && n == 32 && k == 2 && b == 1)
880 return ROCDL::mfma_f32_32x32x2f32::getOperationName();
881 if (m == 16 && n == 16 && k == 4 && b == 1)
882 return ROCDL::mfma_f32_16x16x4f32::getOperationName();
883 }
884
885 if (sourceElem.isF16() && destElem.isF32()) {
886 if (chipset >= kGfx950) {
887 if (m == 32 && n == 32 && k == 16 && b == 1)
888 return ROCDL::mfma_f32_32x32x16_f16::getOperationName();
889 if (m == 16 && n == 16 && k == 32 && b == 1)
890 return ROCDL::mfma_f32_16x16x32_f16::getOperationName();
891 }
892 if (m == 32 && n == 32 && k == 4 && b == 2)
893 return ROCDL::mfma_f32_32x32x4f16::getOperationName();
894 if (m == 16 && n == 16 && k == 4 && b == 4)
895 return ROCDL::mfma_f32_16x16x4f16::getOperationName();
896 if (m == 4 && n == 4 && k == 4 && b == 16)
897 return ROCDL::mfma_f32_4x4x4f16::getOperationName();
898 if (m == 32 && n == 32 && k == 8 && b == 1)
899 return ROCDL::mfma_f32_32x32x8f16::getOperationName();
900 if (m == 16 && n == 16 && k == 16 && b == 1)
901 return ROCDL::mfma_f32_16x16x16f16::getOperationName();
902 }
903
904 if (sourceElem.isBF16() && destElem.isF32()) {
905 if (chipset >= kGfx950) {
906 if (m == 32 && n == 32 && k == 16 && b == 1)
907 return ROCDL::mfma_f32_32x32x16_bf16::getOperationName();
908 if (m == 16 && n == 16 && k == 32 && b == 1)
909 return ROCDL::mfma_f32_16x16x32_bf16::getOperationName();
910 }
911 if (chipset >= kGfx90a) {
912 if (m == 32 && n == 32 && k == 4 && b == 2)
913 return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
914 if (m == 16 && n == 16 && k == 4 && b == 4)
915 return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
916 if (m == 4 && n == 4 && k == 4 && b == 16)
917 return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
918 if (m == 32 && n == 32 && k == 8 && b == 1)
919 return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
920 if (m == 16 && n == 16 && k == 16 && b == 1)
921 return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
922 }
923 if (m == 32 && n == 32 && k == 2 && b == 2)
924 return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
925 if (m == 16 && n == 16 && k == 2 && b == 4)
926 return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
927 if (m == 4 && n == 4 && k == 2 && b == 16)
928 return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
929 if (m == 32 && n == 32 && k == 4 && b == 1)
930 return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
931 if (m == 16 && n == 16 && k == 8 && b == 1)
932 return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
933 }
934
935 if (sourceElem.isInteger(8) && destElem.isInteger(32)) {
936 if (chipset >= kGfx950) {
937 if (m == 32 && n == 32 && k == 32 && b == 1)
938 return ROCDL::mfma_i32_32x32x32_i8::getOperationName();
939 if (m == 16 && n == 16 && k == 64 && b == 1)
940 return ROCDL::mfma_i32_16x16x64_i8::getOperationName();
941 }
942 if (m == 32 && n == 32 && k == 4 && b == 2)
943 return ROCDL::mfma_i32_32x32x4i8::getOperationName();
944 if (m == 16 && n == 16 && k == 4 && b == 4)
945 return ROCDL::mfma_i32_16x16x4i8::getOperationName();
946 if (m == 4 && n == 4 && k == 4 && b == 16)
947 return ROCDL::mfma_i32_4x4x4i8::getOperationName();
948 if (m == 32 && n == 32 && k == 8 && b == 1)
949 return ROCDL::mfma_i32_32x32x8i8::getOperationName();
950 if (m == 16 && n == 16 && k == 16 && b == 1)
951 return ROCDL::mfma_i32_16x16x16i8::getOperationName();
952 if (m == 32 && n == 32 && k == 16 && b == 1 && chipset >= kGfx942)
953 return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
954 if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >= kGfx942)
955 return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
956 }
957
958 if (sourceElem.isF64() && destElem.isF64() && chipset >= kGfx90a) {
959 if (m == 16 && n == 16 && k == 4 && b == 1)
960 return ROCDL::mfma_f64_16x16x4f64::getOperationName();
961 if (m == 4 && n == 4 && k == 4 && b == 4)
962 return ROCDL::mfma_f64_4x4x4f64::getOperationName();
963 }
964
965 if (destElem.isF32() && typeIsExpectedBf8ForChipset(chipset, sourceElem)) {
966 // Known to be correct because there are no scalar f8 instructions and
967 // because a length mismatch will have been caught by the verifier.
968 Type sourceBElem =
969 cast<VectorType>(mfma.getSourceB().getType()).getElementType();
970 if (m == 16 && n == 16 && k == 32 && b == 1) {
971 if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
972 return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
973 if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
974 return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
975 }
976 if (m == 32 && n == 32 && k == 16 && b == 1) {
977 if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
978 return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
979 if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
980 return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
981 }
982 }
983
984 if (destElem.isF32() && typeIsExpectedFp8ForChipset(chipset, sourceElem)) {
985 Type sourceBElem =
986 cast<VectorType>(mfma.getSourceB().getType()).getElementType();
987 if (m == 16 && n == 16 && k == 32 && b == 1) {
988 if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
989 return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
990 if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
991 return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
992 }
993 if (m == 32 && n == 32 && k == 16 && b == 1) {
994 if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
995 return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
996 if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
997 return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
998 }
999 }
1000
1001 return std::nullopt;
1002}
1003
1004static std::optional<uint32_t> smallFloatTypeToFormatCode(Type mlirElemType) {
1005 return TypeSwitch<Type, std::optional<uint32_t>>(mlirElemType)
1006 .Case([](Float8E4M3FNType) { return 0u; })
1007 .Case([](Float8E5M2Type) { return 1u; })
1008 .Case([](Float6E2M3FNType) { return 2u; })
1009 .Case([](Float6E3M2FNType) { return 3u; })
1010 .Case([](Float4E2M1FNType) { return 4u; })
1011 .Default(std::nullopt);
1012}
1013
1014/// If there is a scaled MFMA instruction for the input element types `aType`
1015/// and `bType`, output type `destType`, problem size M, N, K, and B (number of
1016/// blocks) on the given `chipset`, return a tuple consisting of the
1017/// OperationName of the intrinsic and the type codes that need to be passed to
1018/// that intrinsic. Note that this is also used to implement some un-scaled
1019/// MFMAs, since the compiler represents the ordinary instruction as a "scaled"
1020/// MFMA with a scale of 0.
1021static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1022mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m,
1023 uint32_t n, uint32_t k, uint32_t b, Chipset chipset) {
1024 aType = getElementTypeOrSelf(aType);
1025 bType = getElementTypeOrSelf(bType);
1026 destType = getElementTypeOrSelf(destType);
1027
1028 if (chipset < kGfx950)
1029 return std::nullopt;
1030 if (!isa<Float32Type>(destType))
1031 return std::nullopt;
1032
1033 std::optional<uint32_t> aTypeCode = smallFloatTypeToFormatCode(aType);
1034 std::optional<uint32_t> bTypeCode = smallFloatTypeToFormatCode(bType);
1035 if (!aTypeCode || !bTypeCode)
1036 return std::nullopt;
1037
1038 if (m == 32 && n == 32 && k == 64 && b == 1)
1039 return std::tuple{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(),
1040 *aTypeCode, *bTypeCode};
1041 if (m == 16 && n == 16 && k == 128 && b == 1)
1042 return std::tuple{
1043 ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), *aTypeCode,
1044 *bTypeCode};
1045
1046 return std::nullopt;
1047}
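// For instance, assuming sourceA of type vector<32xf8E4M3FN>, sourceB of type
// vector<32xf8E5M2>, destC of type vector<16xf32>, and (m, n, k, b) =
// (32, 32, 64, 1) on gfx950, this returns
//   {mfma_scale_f32_32x32x64_f8f6f4, /*aTypeCode=*/0, /*bTypeCode=*/1}
// with the type codes taken from smallFloatTypeToFormatCode above.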
1048
1049static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1050mfmaOpToScaledIntrinsic(MFMAOp mfma, Chipset chipset) {
1051 return mfmaOpToScaledIntrinsic(
1052 mfma.getSourceA().getType(), mfma.getSourceB().getType(),
1053 mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(),
1054 mfma.getBlocks(), chipset);
1055}
1056
1057static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1058mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
1059 return mfmaOpToScaledIntrinsic(smfma.getSourceA().getType(),
1060 smfma.getSourceB().getType(),
1061 smfma.getDestC().getType(), smfma.getM(),
1062 smfma.getN(), smfma.getK(), 1u, chipset);
1063}
1064
1065/// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
1066/// for RDNA3/4 architectures.
1067static std::optional<StringRef>
1068wmmaOpToIntrinsicRDNA(Type elemSourceType, Type elemBSourceType,
1069 Type elemDestType, uint32_t k, bool isRDNA3) {
1070 using fp8 = Float8E4M3FNType;
1071 using bf8 = Float8E5M2Type;
1072
1073 // Handle k == 16 for RDNA3/4.
1074 if (k == 16) {
1075 // Common patterns for RDNA3 and RDNA4.
1076 if (elemSourceType.isF16() && elemDestType.isF32())
1077 return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
1078 if (elemSourceType.isBF16() && elemDestType.isF32())
1079 return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
1080 if (elemSourceType.isF16() && elemDestType.isF16())
1081 return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
1082 if (elemSourceType.isBF16() && elemDestType.isBF16())
1083 return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
1084 if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
1085 return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
1086
1087 // RDNA3 specific patterns.
1088 if (isRDNA3) {
1089 if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
1090 return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
1091 return std::nullopt;
1092 }
1093
1094 // RDNA4 specific patterns (fp8/bf8).
1095 if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
1096 elemDestType.isF32())
1097 return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
1098 if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
1099 elemDestType.isF32())
1100 return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
1101 if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
1102 elemDestType.isF32())
1103 return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
1104 if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
1105 elemDestType.isF32())
1106 return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
1107 if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
1108 return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
1109
1110 return std::nullopt;
1111 }
1112
1113 // Handle k == 32 for RDNA4.
1114 if (k == 32 && !isRDNA3) {
1115 if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
1116 return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
1117 }
1118
1119 return std::nullopt;
1120}
1121
1122/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
1123/// for the gfx1250 architecture.
1124static std::optional<StringRef> wmmaOpToIntrinsicGfx1250(Type elemSourceType,
1125 Type elemBSourceType,
1126 Type elemDestType,
1127 uint32_t k) {
1128 using fp8 = Float8E4M3FNType;
1129 using bf8 = Float8E5M2Type;
1130
1131 if (k == 4) {
1132 if (elemSourceType.isF32() && elemDestType.isF32())
1133 return ROCDL::wmma_f32_16x16x4_f32::getOperationName();
1134
1135 return std::nullopt;
1136 }
1137
1138 if (k == 32) {
1139 if (elemSourceType.isF16() && elemDestType.isF32())
1140 return ROCDL::wmma_f32_16x16x32_f16::getOperationName();
1141 if (elemSourceType.isBF16() && elemDestType.isF32())
1142 return ROCDL::wmma_f32_16x16x32_bf16::getOperationName();
1143 if (elemSourceType.isF16() && elemDestType.isF16())
1144 return ROCDL::wmma_f16_16x16x32_f16::getOperationName();
1145 if (elemSourceType.isBF16() && elemDestType.isBF16())
1146 return ROCDL::wmma_bf16_16x16x32_bf16::getOperationName();
1147
1148 return std::nullopt;
1149 }
1150
1151 if (k == 64) {
1152 if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1153 if (elemDestType.isF32())
1154 return ROCDL::wmma_f32_16x16x64_fp8_fp8::getOperationName();
1155 if (elemDestType.isF16())
1156 return ROCDL::wmma_f16_16x16x64_fp8_fp8::getOperationName();
1157 }
1158 if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1159 if (elemDestType.isF32())
1160 return ROCDL::wmma_f32_16x16x64_fp8_bf8::getOperationName();
1161 if (elemDestType.isF16())
1162 return ROCDL::wmma_f16_16x16x64_fp8_bf8::getOperationName();
1163 }
1164 if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1165 if (elemDestType.isF32())
1166 return ROCDL::wmma_f32_16x16x64_bf8_bf8::getOperationName();
1167 if (elemDestType.isF16())
1168 return ROCDL::wmma_f16_16x16x64_bf8_bf8::getOperationName();
1169 }
1170 if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1171 if (elemDestType.isF32())
1172 return ROCDL::wmma_f32_16x16x64_bf8_fp8::getOperationName();
1173 if (elemDestType.isF16())
1174 return ROCDL::wmma_f16_16x16x64_bf8_fp8::getOperationName();
1175 }
1176 if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
1177 return ROCDL::wmma_i32_16x16x64_iu8::getOperationName();
1178
1179 return std::nullopt;
1180 }
1181
1182 if (k == 128) {
1183 if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1184 if (elemDestType.isF32())
1185 return ROCDL::wmma_f32_16x16x128_fp8_fp8::getOperationName();
1186 if (elemDestType.isF16())
1187 return ROCDL::wmma_f16_16x16x128_fp8_fp8::getOperationName();
1188 }
1189 if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1190 if (elemDestType.isF32())
1191 return ROCDL::wmma_f32_16x16x128_fp8_bf8::getOperationName();
1192 if (elemDestType.isF16())
1193 return ROCDL::wmma_f16_16x16x128_fp8_bf8::getOperationName();
1194 }
1195 if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1196 if (elemDestType.isF32())
1197 return ROCDL::wmma_f32_16x16x128_bf8_bf8::getOperationName();
1198 if (elemDestType.isF16())
1199 return ROCDL::wmma_f16_16x16x128_bf8_bf8::getOperationName();
1200 }
1201 if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1202 if (elemDestType.isF32())
1203 return ROCDL::wmma_f32_16x16x128_bf8_fp8::getOperationName();
1204 if (elemDestType.isF16())
1205 return ROCDL::wmma_f16_16x16x128_bf8_fp8::getOperationName();
1206 }
1207
1208 return std::nullopt;
1209 }
1210
1211 return std::nullopt;
1212}
1213
1214/// Returns the `rocdl` intrinsic corresponding to a SparseMFMA (smfmac)
1215/// operation if one exists. This includes checking to ensure the intrinsic is
1216/// supported on the architecture you are compiling for.
1217static std::optional<StringRef> smfmacOpToIntrinsic(SparseMFMAOp op,
1218 Chipset chipset) {
1219 bool isGfx950 = chipset >= kGfx950;
1220 auto isFp8 = [&](Type t) { return typeIsExpectedFp8ForChipset(chipset, t); };
1221 auto isBf8 = [&](Type t) { return typeIsExpectedBf8ForChipset(chipset, t); };
1222
1223 uint32_t m = op.getM(), n = op.getN(), k = op.getK();
1224 Type sourceAElem = getElementTypeOrSelf(op.getSourceA().getType());
1225 Type sourceBElem = getElementTypeOrSelf(op.getSourceB().getType());
1226 Type destElem = getElementTypeOrSelf(op.getDestC().getType());
1227
1228 if (m == 16 && n == 16 && k == 32) {
1229 if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32())
1230 return ROCDL::smfmac_f32_16x16x32_f16::getOperationName();
1231 if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32())
1232 return ROCDL::smfmac_f32_16x16x32_bf16::getOperationName();
1233 }
1234
1235 if (m == 16 && n == 16 && k == 64) {
1236 if (isGfx950) {
1237 if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32())
1238 return ROCDL::smfmac_f32_16x16x64_f16::getOperationName();
1239 if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32())
1240 return ROCDL::smfmac_f32_16x16x64_bf16::getOperationName();
1241 }
1242 if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
1243 destElem.isInteger(32))
1244 return ROCDL::smfmac_i32_16x16x64_i8::getOperationName();
1245 if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
1246 return ROCDL::smfmac_f32_16x16x64_fp8_fp8::getOperationName();
1247 if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
1248 return ROCDL::smfmac_f32_16x16x64_fp8_bf8::getOperationName();
1249 if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
1250 return ROCDL::smfmac_f32_16x16x64_bf8_fp8::getOperationName();
1251 if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
1252 return ROCDL::smfmac_f32_16x16x64_bf8_bf8::getOperationName();
1253 }
1254
1255 if (m == 16 && n == 16 && k == 128 && isGfx950) {
1256 if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
1257 destElem.isInteger(32))
1258 return ROCDL::smfmac_i32_16x16x128_i8::getOperationName();
1259 if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
1260 return ROCDL::smfmac_f32_16x16x128_fp8_fp8::getOperationName();
1261 if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
1262 return ROCDL::smfmac_f32_16x16x128_fp8_bf8::getOperationName();
1263 if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
1264 return ROCDL::smfmac_f32_16x16x128_bf8_fp8::getOperationName();
1265 if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
1266 return ROCDL::smfmac_f32_16x16x128_bf8_bf8::getOperationName();
1267 }
1268
1269 if (m == 32 && n == 32 && k == 16) {
1270 if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32())
1271 return ROCDL::smfmac_f32_32x32x16_f16::getOperationName();
1272 if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32())
1273 return ROCDL::smfmac_f32_32x32x16_bf16::getOperationName();
1274 }
1275
1276 if (m == 32 && n == 32 && k == 32) {
1277 if (isGfx950) {
1278 if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32())
1279 return ROCDL::smfmac_f32_32x32x32_f16::getOperationName();
1280 if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32())
1281 return ROCDL::smfmac_f32_32x32x32_bf16::getOperationName();
1282 }
1283 if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
1284 destElem.isInteger(32))
1285 return ROCDL::smfmac_i32_32x32x32_i8::getOperationName();
1286 if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
1287 return ROCDL::smfmac_f32_32x32x32_fp8_fp8::getOperationName();
1288 if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
1289 return ROCDL::smfmac_f32_32x32x32_fp8_bf8::getOperationName();
1290 if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
1291 return ROCDL::smfmac_f32_32x32x32_bf8_fp8::getOperationName();
1292 if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
1293 return ROCDL::smfmac_f32_32x32x32_bf8_bf8::getOperationName();
1294 }
1295
1296 if (m == 32 && n == 32 && k == 64 && isGfx950) {
1297 if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
1298 destElem.isInteger(32))
1299 return ROCDL::smfmac_i32_32x32x64_i8::getOperationName();
1300 if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
1301 return ROCDL::smfmac_f32_32x32x64_fp8_fp8::getOperationName();
1302 if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
1303 return ROCDL::smfmac_f32_32x32x64_fp8_bf8::getOperationName();
1304 if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
1305 return ROCDL::smfmac_f32_32x32x64_bf8_fp8::getOperationName();
1306 if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
1307 return ROCDL::smfmac_f32_32x32x64_bf8_bf8::getOperationName();
1308 }
1309
1310 return std::nullopt;
1311}
1312
1313/// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
1314/// if one exists. This includes checking to ensure the intrinsic is supported
1315/// on the architecture you are compiling for.
1316static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
1317 Chipset chipset) {
1318 auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
1319 auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
1320 auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
1321 Type elemSourceType = sourceVectorType.getElementType();
1322 Type elemBSourceType = sourceBVectorType.getElementType();
1323 Type elemDestType = destVectorType.getElementType();
1324
1325 const uint32_t k = wmma.getK();
1326 const bool isRDNA3 = chipset.majorVersion == 11;
1327 const bool isRDNA4 = chipset.majorVersion == 12 && chipset.minorVersion == 0;
1328
1329 // Handle RDNA3 and RDNA4.
1330 if (isRDNA3 || isRDNA4)
1331 return wmmaOpToIntrinsicRDNA(elemSourceType, elemBSourceType, elemDestType,
1332 k, isRDNA3);
1333
1334 // Handle gfx1250.
1335 if (chipset == kGfx1250)
1336 return wmmaOpToIntrinsicGfx1250(elemSourceType, elemBSourceType,
1337 elemDestType, k);
1338
1339 return std::nullopt;
1340}
1341
1342namespace {
1343struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
1344 MFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1345 : ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {}
1346
1347 Chipset chipset;
1348
1349 LogicalResult
1350 matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
1351 ConversionPatternRewriter &rewriter) const override {
1352 Location loc = op.getLoc();
1353 Type outType = typeConverter->convertType(op.getDestD().getType());
1354 Type intrinsicOutType = outType;
1355 if (auto outVecType = dyn_cast<VectorType>(outType))
1356 if (outVecType.getElementType().isBF16())
1357 intrinsicOutType = outVecType.clone(rewriter.getI16Type());
1358
1359 if (chipset.majorVersion != 9 || chipset < kGfx908)
1360 return op->emitOpError("MFMA only supported on gfx908+");
1361 uint32_t getBlgpField = static_cast<uint32_t>(op.getBlgp());
1362 if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
1363 if (chipset < kGfx942)
1364 return op.emitOpError("negation unsupported on older than gfx942");
1365 getBlgpField |=
1366 op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
1367 }
1368 std::optional<StringRef> maybeIntrinsic = mfmaOpToIntrinsic(op, chipset);
1369 std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1370 maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
1371 if (!maybeIntrinsic.has_value() && !maybeScaledIntrinsic.has_value())
1372 return op.emitOpError("no intrinsic matching MFMA size on given chipset");
1373
1374 bool isScaled =
1375 !maybeIntrinsic.has_value() && maybeScaledIntrinsic.has_value();
1376 if (isScaled &&
1377 (adaptor.getAbid() > 0 || getBlgpField > 0 || op.getCbsz() > 0)) {
1378 return op.emitOpError(
1379 "non-default abid, blgp, and cbsz aren't supported on MFMAs that can "
1380 "be scaled as those fields are used for type information");
1381 }
1382
1383 StringRef intrinsicName =
1384 isScaled ? std::get<0>(*maybeScaledIntrinsic) : *maybeIntrinsic;
1385 // Determine if we can use bf16 in the intrinsic. Newer MFMAs on gfx950+
1386 // allow bf16 as the input. For reference, check the IntrinsicsAMDGPU.td file.
1387 bool allowBf16 = [&]() {
1388 if (chipset < kGfx950)
1389 return false;
1390 if (isScaled)
1391 return true;
1392 return intrinsicName.contains("16x16x32.bf16") ||
1393 intrinsicName.contains("32x32x16.bf16");
1394 }();
1395 OperationState loweredOp(loc, intrinsicName);
1396 loweredOp.addTypes(intrinsicOutType);
1397 loweredOp.addOperands({packSmallFloatVectorOperand(
1398 rewriter, loc, adaptor.getSourceA(), allowBf16),
1399 packSmallFloatVectorOperand(
1400 rewriter, loc, adaptor.getSourceB(), allowBf16),
1401 adaptor.getDestC()});
1402 if (isScaled) {
1403 Value zero = createI32Constant(rewriter, loc, 0);
1404 auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
1405 loweredOp.addOperands({createI32Constant(rewriter, loc, aTypeCode),
1406 createI32Constant(rewriter, loc, bTypeCode),
1407 /*scale A byte=*/zero, /*scale A=*/zero,
1408 /*scale B byte=*/zero, /*scale B=*/zero});
1409 } else {
1410 loweredOp.addOperands({createI32Constant(rewriter, loc, op.getCbsz()),
1411 createI32Constant(rewriter, loc, op.getAbid()),
1412 createI32Constant(rewriter, loc, getBlgpField)});
1413 };
1414 Value lowered = rewriter.create(loweredOp)->getResult(0);
1415 if (outType != intrinsicOutType)
1416 lowered = LLVM::BitcastOp::create(rewriter, loc, outType, lowered);
1417 rewriter.replaceOp(op, lowered);
1418 return success();
1419 }
1420};
1421
1422struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
1423 ScaledMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1424 : ConvertOpToLLVMPattern(converter), chipset(chipset) {}
1425
1426 Chipset chipset;
1427
1428 LogicalResult
1429 matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor,
1430 ConversionPatternRewriter &rewriter) const override {
1431 Location loc = op.getLoc();
1432 Type intrinsicOutType = typeConverter->convertType(op.getDestD().getType());
1433
1434 if (chipset.majorVersion != 9 || chipset < kGfx950)
1435 return op->emitOpError("scaled MFMA only supported on gfx950+");
1436 std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1437 maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
1438 if (!maybeScaledIntrinsic.has_value())
1439 return op.emitOpError(
1440 "no intrinsic matching scaled MFMA size on given chipset");
1441
1442 auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
1443 OperationState loweredOp(loc, intrinsicName);
1444 loweredOp.addTypes(intrinsicOutType);
1445 loweredOp.addOperands(
1446 {packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceA()),
1447 packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceB()),
1448 adaptor.getDestC()});
1449 Value scalesIdxA =
1450 createI32Constant(rewriter, loc, adaptor.getScalesIdxA());
1451 Value scalesIdxB =
1452 createI32Constant(rewriter, loc, adaptor.getScalesIdxB());
1453 loweredOp.addOperands(
1454 {createI32Constant(rewriter, loc, aTypeCode),
1455 createI32Constant(rewriter, loc, bTypeCode),
1456 /*scales idx A=*/scalesIdxA,
1457 /*scales A*/
1458 castScaleOperand(rewriter, loc, adaptor.getScalesA()),
1459 /*scales idx B=*/scalesIdxB,
1460 /*scales B*/
1461 castScaleOperand(rewriter, loc, adaptor.getScalesB())});
1462 Value lowered = rewriter.create(loweredOp)->getResult(0);
1463 rewriter.replaceOp(op, lowered);
1464 return success();
1465 }
1466};
1467
1468struct SparseMFMAOpLowering : public ConvertOpToLLVMPattern<SparseMFMAOp> {
1469 SparseMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1470 : ConvertOpToLLVMPattern<SparseMFMAOp>(converter), chipset(chipset) {}
1471
1472 Chipset chipset;
1473
1474 LogicalResult
1475 matchAndRewrite(SparseMFMAOp op, SparseMFMAOpAdaptor adaptor,
1476 ConversionPatternRewriter &rewriter) const override {
1477 Location loc = op.getLoc();
1478 auto outType =
1479 typeConverter->convertType<VectorType>(op.getDestC().getType());
1480 if (!outType)
1481 return rewriter.notifyMatchFailure(op, "type conversion failed");
1482
1483 // smfmac is supported on gfx942 and gfx950.
1484 if (chipset.majorVersion != 9 || chipset < kGfx942)
1485 return op->emitOpError("sparse MFMA (smfmac) only supported on gfx942+");
1486 bool isGfx950 = chipset >= kGfx950;
1487
1488 Value a = convertSparseMFMAVectorOperand(rewriter, loc,
1489 adaptor.getSourceA(), isGfx950);
1490 Value b = convertSparseMFMAVectorOperand(rewriter, loc,
1491 adaptor.getSourceB(), isGfx950);
1492 Value c = adaptor.getDestC();
1493
1494 std::optional<StringRef> maybeIntrinsic = smfmacOpToIntrinsic(op, chipset);
1495 if (!maybeIntrinsic.has_value())
1496 return op.emitOpError(
1497 "no intrinsic matching sparse MFMA on the given chipset");
1498
1499 // Bitcast sparse indices from vector<4xi8> or vector<2xi16> to i32.
1500 Value sparseIdx = LLVM::BitcastOp::create(
1501 rewriter, loc, rewriter.getI32Type(), adaptor.getSparseIdx());
1502
1503 OperationState loweredOp(loc, maybeIntrinsic.value());
1504 loweredOp.addTypes(outType);
1505 loweredOp.addOperands({a, b, c, sparseIdx,
1506 createI32Constant(rewriter, loc, op.getCbsz()),
1507 createI32Constant(rewriter, loc, op.getAbid())});
1508 Value lowered = rewriter.create(loweredOp)->getResult(0);
1509 rewriter.replaceOp(op, lowered);
1510 return success();
1511 }
1512};
1513
1514struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
1515 WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1516 : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}
1517
1518 Chipset chipset;
1519
1520 LogicalResult
1521 matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
1522 ConversionPatternRewriter &rewriter) const override {
1523 Location loc = op.getLoc();
1524 auto outType =
1525 typeConverter->convertType<VectorType>(op.getDestD().getType());
1526 if (!outType)
1527 return rewriter.notifyMatchFailure(op, "type conversion failed");
1528
1529 if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
1530 return op->emitOpError("WMMA only supported on gfx11 and gfx12");
1531
1532 bool isGFX1250 = chipset >= kGfx1250;
1533
1534 // The WMMA operations represent vectors of bf16s as vectors of i16s
1535 // (except on gfx1250), so we need to bitcast bfloats to i16 and then
1536 // bitcast them back.
1537 auto aType = cast<VectorType>(adaptor.getSourceA().getType());
1538 auto bType = cast<VectorType>(adaptor.getSourceB().getType());
1539 auto destCType = cast<VectorType>(adaptor.getDestC().getType());
1540 bool castAToI16 = aType.getElementType().isBF16() && !isGFX1250;
1541 bool castBToI16 = bType.getElementType().isBF16() && !isGFX1250;
1542 bool castDestCToI16 = destCType.getElementType().isBF16() && !isGFX1250;
1543 bool castOutToI16 = outType.getElementType().isBF16() && !isGFX1250;
1544 VectorType rawOutType = outType;
1545 if (castOutToI16)
1546 rawOutType = outType.clone(rewriter.getI16Type());
1547 Value a = adaptor.getSourceA();
1548 if (castAToI16)
1549 a = LLVM::BitcastOp::create(rewriter, loc,
1550 aType.clone(rewriter.getI16Type()), a);
1551 Value b = adaptor.getSourceB();
1552 if (castBToI16)
1553 b = LLVM::BitcastOp::create(rewriter, loc,
1554 bType.clone(rewriter.getI16Type()), b);
1555 Value destC = adaptor.getDestC();
1556 if (castDestCToI16)
1557 destC = LLVM::BitcastOp::create(
1558 rewriter, loc, destCType.clone(rewriter.getI16Type()), destC);
1559
1560 std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);
1561
1562 if (!maybeIntrinsic.has_value())
1563 return op.emitOpError("no intrinsic matching WMMA on the given chipset");
1564
1565 if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0)
1566 return op.emitOpError("subwordOffset not supported on gfx12+");
1567
1568 SmallVector<Value, 4> operands;
1569 SmallVector<NamedAttribute, 4> attrs;
1570 wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(), a,
1571 op.getSourceA(), operands, attrs, "signA");
1572 wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(), b,
1573 op.getSourceB(), operands, attrs, "signB");
1574 wmmaPushOutputOperand(rewriter, loc, typeConverter, destC,
1575 op.getSubwordOffset(), op.getClamp(), operands,
1576 attrs);
1577
1578 OperationState loweredOp(loc, *maybeIntrinsic);
1579 loweredOp.addTypes(rawOutType);
1580 loweredOp.addOperands(operands);
1581 loweredOp.addAttributes(attrs);
1582 Operation *lowered = rewriter.create(loweredOp);
1583
1584 Operation *maybeCastBack = lowered;
1585 if (rawOutType != outType)
1586 maybeCastBack = LLVM::BitcastOp::create(rewriter, loc, outType,
1587 lowered->getResult(0));
1588 rewriter.replaceOp(op, maybeCastBack->getResults());
1589
1590 return success();
1591 }
1592};
1593
1594struct ScaledWMMAOpLowering : public ConvertOpToLLVMPattern<ScaledWMMAOp> {
1595 ScaledWMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1596 : ConvertOpToLLVMPattern<ScaledWMMAOp>(converter), chipset(chipset) {}
1597
1598 Chipset chipset;
1599
1600 LogicalResult
1601 matchAndRewrite(ScaledWMMAOp op, ScaledWMMAOpAdaptor adaptor,
1602 ConversionPatternRewriter &rewriter) const override {
1603 Location loc = op.getLoc();
1604 auto outType =
1605 typeConverter->convertType<VectorType>(op.getDestD().getType());
1606 if (!outType)
1607 return rewriter.notifyMatchFailure(op, "type conversion failed");
1608
1609 if (chipset < kGfx1250)
1610 return op->emitOpError("WMMA scale only supported on gfx1250+");
1611
1612 int64_t m = op.getM();
1613 int64_t n = op.getN();
1614 int64_t k = op.getK();
1615
1616 Type aElemType = getElementTypeOrSelf(op.getSourceA().getType());
1617 Type bElemType = getElementTypeOrSelf(op.getSourceB().getType());
1618
1619 std::optional<uint32_t> aFmtCode = smallFloatTypeToFormatCode(aElemType);
1620 std::optional<uint32_t> bFmtCode = smallFloatTypeToFormatCode(bElemType);
1621
1622 if (!aFmtCode || !bFmtCode)
1623 return op.emitOpError("unsupported element types for scaled_wmma");
1624
1625 // Get scale vector types and determine variant (scale vs scale16).
1626 auto scaleAVecType = cast<VectorType>(op.getScaleA().getType());
1627 auto scaleBVecType = cast<VectorType>(op.getScaleB().getType());
1628
1629 if (scaleAVecType.getNumElements() != scaleBVecType.getNumElements())
1630 return op.emitOpError("scaleA and scaleB must have equal vector length");
1631
1632 // Extract scale format from element types.
1633 Type scaleAElemType = scaleAVecType.getElementType();
1634 Type scaleBElemType = scaleBVecType.getElementType();
1635
1636 std::optional<uint32_t> scaleAFmt = getWmmaScaleFormat(scaleAElemType);
1637 std::optional<uint32_t> scaleBFmt = getWmmaScaleFormat(scaleBElemType);
1638
1639 if (!scaleAFmt || !scaleBFmt)
1640 return op.emitOpError("unsupported scale element types");
1641
1642 // Determine which intrinsic to use based on dimensions.
1643 bool isScale16 = (scaleAVecType.getNumElements() == 8);
1644 std::optional<StringRef> intrinsicName =
1645 getScaledWmmaIntrinsicName(m, n, k, isScale16);
1646 if (!intrinsicName)
1647 return op.emitOpError("unsupported scaled_wmma dimensions: ")
1648 << m << "x" << n << "x" << k;
1649
1650 SmallVector<NamedAttribute, 8> attrs;
1651
1652 // The f4 variant does not have fmtA and fmtB attributes.
1653 bool is32x16 = (m == 32 && n == 16 && k == 128);
1654 if (!is32x16) {
1655 attrs.emplace_back("fmtA", rewriter.getI32IntegerAttr(*aFmtCode));
1656 attrs.emplace_back("fmtB", rewriter.getI32IntegerAttr(*bFmtCode));
1657 }
1658
1659 // modC uses default value of 0.
1660 attrs.emplace_back("modC", rewriter.getI16IntegerAttr(0));
1661
1662 // Scale attributes. Convert user-facing firstScaleLane (0 or 16) to the
1663 // half of the wave that is being selected (0 or 1).
1664 attrs.emplace_back(
1665 "scaleAType", rewriter.getI32IntegerAttr(op.getAFirstScaleLane() / 16));
1666 attrs.emplace_back("fmtScaleA", rewriter.getI32IntegerAttr(*scaleAFmt));
1667 attrs.emplace_back(
1668 "scaleBType", rewriter.getI32IntegerAttr(op.getBFirstScaleLane() / 16));
1669 attrs.emplace_back("fmtScaleB", rewriter.getI32IntegerAttr(*scaleBFmt));
1670
1671 // Reuse flags use default value of false.
1672 attrs.emplace_back("reuseA", rewriter.getBoolAttr(false));
1673 attrs.emplace_back("reuseB", rewriter.getBoolAttr(false));
1674
1675 // Convert typed float vectors to packed format.
1676 Value sourceA =
1677 packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceA());
1678 Value sourceB =
1679 packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceB());
1680
1681 // Pack scale vectors into i32/i64.
1682 Value packedScaleA = castScaleOperand(rewriter, loc, adaptor.getScaleA());
1683 Value packedScaleB = castScaleOperand(rewriter, loc, adaptor.getScaleB());
1684
1685 // Create the intrinsic call.
1686 OperationState loweredOp(loc, *intrinsicName);
1687 loweredOp.addTypes(outType);
1688 loweredOp.addOperands(
1689 {sourceA, sourceB, adaptor.getDestC(), packedScaleA, packedScaleB});
1690 loweredOp.addAttributes(attrs);
1691
1692 Operation *lowered = rewriter.create(loweredOp);
1693 rewriter.replaceOp(op, lowered->getResults());
1694
1695 return success();
1696 }
1697};
1698
1699struct TransposeLoadOpLowering
1700 : public ConvertOpToLLVMPattern<TransposeLoadOp> {
1701 TransposeLoadOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1702 : ConvertOpToLLVMPattern<TransposeLoadOp>(converter), chipset(chipset) {}
1703
1704 Chipset chipset;
1705
1706 LogicalResult
1707 matchAndRewrite(TransposeLoadOp op, TransposeLoadOpAdaptor adaptor,
1708 ConversionPatternRewriter &rewriter) const override {
1709 if (chipset != kGfx950)
1710 return op.emitOpError("Non-gfx950 chipset not supported");
1711
1712 Location loc = op.getLoc();
1713 auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
1714
1715 // Elements in sub-byte memrefs are stored non-contiguously, so reject
1716 // sub-byte source memrefs; use emulated memrefs instead.
1717 size_t srcElementSize =
1718 srcMemRefType.getElementType().getIntOrFloatBitWidth();
1719 if (srcElementSize < 8)
1720 return op.emitOpError("Expect source memref to have at least 8 bits "
1721 "element size, got ")
1722 << srcElementSize;
1723
1724 auto resultType = cast<VectorType>(op.getResult().getType());
1725 Value srcPtr =
1726 getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
1727 (adaptor.getSrcIndices()));
1728
1729 size_t numElements = resultType.getNumElements();
1730 size_t elementTypeSize =
1731 resultType.getElementType().getIntOrFloatBitWidth();
1732
1733 // ROCDL transpose load intrinsics return vectors of 32-bit integers if
1734 // the element size is smaller than 16 bits.
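    // Editorial example (added for clarity): a result of 16 4-bit elements
    // (64 bits total) maps to a vector<2 x i32> intrinsic result below, while
    // 16-bit element types keep their converted LLVM vector type.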
1735 Type rocdlResultType = VectorType::get((numElements * elementTypeSize) / 32,
1736 rewriter.getIntegerType(32));
1737 Type llvmResultType = typeConverter->convertType(resultType);
1738
1739 switch (elementTypeSize) {
1740 case 4: {
1741 assert(numElements == 16);
1742 auto rocdlOp = ROCDL::ds_read_tr4_b64::create(rewriter, loc,
1743 rocdlResultType, srcPtr);
1744 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
1745 break;
1746 }
1747 case 6: {
1748 assert(numElements == 16);
1749 auto rocdlOp = ROCDL::ds_read_tr6_b96::create(rewriter, loc,
1750 rocdlResultType, srcPtr);
1751 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
1752 break;
1753 }
1754 case 8: {
1755 assert(numElements == 8);
1756 auto rocdlOp = ROCDL::ds_read_tr8_b64::create(rewriter, loc,
1757 rocdlResultType, srcPtr);
1758 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
1759 break;
1760 }
1761 case 16: {
1762 assert(numElements == 4);
1763 rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr16_b64>(op, llvmResultType,
1764 srcPtr);
1765 break;
1766 }
1767 default:
1768 return op.emitOpError("Unsupported element size for transpose load");
1769 }
1770 return success();
1771 }
1772};
1773
1774struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
1775 GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1776 : ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}
1777
1778 Chipset chipset;
1779
1780 LogicalResult
1781 matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor,
1782 ConversionPatternRewriter &rewriter) const override {
1783 if (chipset.majorVersion < 9 || chipset.majorVersion > 10)
1784 return op.emitOpError("pre-gfx9 and post-gfx10 not supported");
1785
1786 Location loc = op.getLoc();
1787
1788 auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
1789 auto dstMemRefType = cast<MemRefType>(op.getDst().getType());
1790
1791 // TODO: instead of only transferring one element per thread, we could
1792 // augment it to transfer multiple elements per thread by issuing multiple
1793 // `global_load_lds` instructions.
1794 Type transferType = op.getTransferType();
1795 int loadWidth = [&]() -> int {
1796 if (auto transferVectorType = dyn_cast<VectorType>(transferType)) {
1797 return (transferVectorType.getNumElements() *
1798 transferVectorType.getElementTypeBitWidth()) /
1799 8;
1800 }
1801 return transferType.getIntOrFloatBitWidth() / 8;
1802 }();
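    // Editorial example (added for clarity): a transfer type of f32 yields
    // loadWidth = 4, while vector<8 x f16> yields 8 * 16 / 8 = 16 bytes,
    // which the checks below restrict to gfx950.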
1803
1804 // Currently only 1, 2, 4, 12 and 16 byte loads are supported.
1805 if (!llvm::is_contained({1, 2, 4, 12, 16}, loadWidth))
1806 return op.emitOpError("chipset unsupported element size");
1807
1808 if (chipset != kGfx950 && llvm::is_contained({12, 16}, loadWidth))
1809 return op.emitOpError("Gather to LDS instructions with 12-byte and "
1810 "16-byte load widths are only supported on gfx950");
1811
1812 Value srcPtr =
1813 getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
1814 (adaptor.getSrcIndices()));
1815 Value dstPtr =
1816 getStridedElementPtr(rewriter, loc, dstMemRefType, adaptor.getDst(),
1817 (adaptor.getDstIndices()));
1818
1819 rewriter.replaceOpWithNewOp<ROCDL::LoadToLDSOp>(
1820 op, srcPtr, dstPtr, rewriter.getI32IntegerAttr(loadWidth),
1821 /*offset=*/rewriter.getI32IntegerAttr(0),
1822 /*aux=*/rewriter.getI32IntegerAttr(0), ArrayAttr{}, ArrayAttr{},
1823 ArrayAttr{});
1824
1825 return success();
1826 }
1827};
1828
1829namespace {
1830struct ExtPackedFp8OpLowering final
1831 : public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
1832 ExtPackedFp8OpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1833 : ConvertOpToLLVMPattern<amdgpu::ExtPackedFp8Op>(converter),
1834 chipset(chipset) {}
1835 Chipset chipset;
1836
1837 LogicalResult
1838 matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
1839 ConversionPatternRewriter &rewriter) const override;
1840};
1841
1842struct ScaledExtPackedMatrixOpLowering final
1843 : public ConvertOpToLLVMPattern<ScaledExtPackedMatrixOp> {
1844 ScaledExtPackedMatrixOpLowering(const LLVMTypeConverter &converter,
1845 Chipset chipset)
1846 : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedMatrixOp>(converter),
1847 chipset(chipset) {}
1848 Chipset chipset;
1849
1850 LogicalResult
1851 matchAndRewrite(ScaledExtPackedMatrixOp op,
1852 ScaledExtPackedMatrixOpAdaptor adaptor,
1853 ConversionPatternRewriter &rewriter) const override;
1854};
1855
1856struct PackedTrunc2xFp8OpLowering final
1857 : public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> {
1858 PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter,
1859 Chipset chipset)
1860 : ConvertOpToLLVMPattern<amdgpu::PackedTrunc2xFp8Op>(converter),
1861 chipset(chipset) {}
1862 Chipset chipset;
1863
1864 LogicalResult
1865 matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
1866 ConversionPatternRewriter &rewriter) const override;
1867};
1868
1869struct PackedStochRoundFp8OpLowering final
1870 : public ConvertOpToLLVMPattern<PackedStochRoundFp8Op> {
1871 PackedStochRoundFp8OpLowering(const LLVMTypeConverter &converter,
1872 Chipset chipset)
1873 : ConvertOpToLLVMPattern<amdgpu::PackedStochRoundFp8Op>(converter),
1874 chipset(chipset) {}
1875 Chipset chipset;
1876
1877 LogicalResult
1878 matchAndRewrite(PackedStochRoundFp8Op op,
1879 PackedStochRoundFp8OpAdaptor adaptor,
1880 ConversionPatternRewriter &rewriter) const override;
1881};
1882
1883struct ScaledExtPackedOpLowering final
1884 : public ConvertOpToLLVMPattern<ScaledExtPackedOp> {
1885 ScaledExtPackedOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1886 : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedOp>(converter),
1887 chipset(chipset) {}
1888 Chipset chipset;
1889
1890 LogicalResult
1891 matchAndRewrite(ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
1892 ConversionPatternRewriter &rewriter) const override;
1893};
1894
1895struct PackedScaledTruncOpLowering final
1896 : public ConvertOpToLLVMPattern<PackedScaledTruncOp> {
1897 PackedScaledTruncOpLowering(const LLVMTypeConverter &converter,
1898 Chipset chipset)
1899 : ConvertOpToLLVMPattern<amdgpu::PackedScaledTruncOp>(converter),
1900 chipset(chipset) {}
1901 Chipset chipset;
1902
1903 LogicalResult
1904 matchAndRewrite(PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
1905 ConversionPatternRewriter &rewriter) const override;
1906};
1907
1908} // end namespace
1909
1910LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
1911 ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
1912 ConversionPatternRewriter &rewriter) const {
1913 Location loc = op.getLoc();
1914 if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
1915 return rewriter.notifyMatchFailure(
1916 loc, "Fp8 conversion instructions are not available on target "
1917 "architecture and their emulation is not implemented");
1918 Type v4i8 =
1919 getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type()));
1920 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
1921 Type f32 = getTypeConverter()->convertType(op.getResult().getType());
1922
1923 Value source = adaptor.getSource();
1924 auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
1925 auto resultVecType = dyn_cast<VectorType>(op.getResult().getType());
1926 Type sourceElemType = getElementTypeOrSelf(op.getSource());
1927 // Extend to a v4i8
1928 if (!sourceVecType || sourceVecType.getNumElements() < 4) {
1929 Value longVec = LLVM::UndefOp::create(rewriter, loc, v4i8);
1930 if (!sourceVecType) {
1931 longVec = LLVM::InsertElementOp::create(
1932 rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0));
1933 } else {
1934 for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
1935 Value idx = createI32Constant(rewriter, loc, i);
1936 Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
1937 longVec =
1938 LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
1939 }
1940 }
1941 source = longVec;
1942 }
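  // Editorial note (added for clarity): sub-4-element sources are widened so
  // the bitcast below always sees 32 bits; the ROCDL conversion op then
  // selects the byte or word indicated by op.getIndex() out of that i32.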
1943 Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
1944 if (resultVecType) {
1945 if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
1946 rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source,
1947 op.getIndex());
1948 } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
1949 rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Fp8Op>(op, f32, i32Source,
1950 op.getIndex());
1951 }
1952 } else {
1953 if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
1954 rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
1955 op.getIndex());
1956 } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
1957 rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
1958 op.getIndex());
1959 }
1960 }
1961 return success();
1962}
1963
1964int32_t getScaleSel(int32_t blockSize, unsigned bitWidth, int32_t scaleWaveHalf,
1965 int32_t firstScaleByte) {
1966 // When lowering amdgpu.scaled_ext_packed_matrix to rocdl.cvt.scale.pk*.f*.f*
1967 // operations, the attributes blockSize, sourceType, scaleWaveHalf, and
1968 // firstScaleByte are merged into a single attribute scaleSel. This is how
1969 // those values are merged together. (Note: scaleWaveHalf isn't a high-level
1970 // attribute but is derived from firstScaleLane.)
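  // Editorial example (added for clarity): for fp8 sources (bitWidth == 8)
  // with blockSize == 32, scaleWaveHalf == 1, and firstScaleByte == 2, the
  // encoding below yields scaleSel = (1 << 3) | (2 << 1) | 0 = 0b1100.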
1971 assert(llvm::is_contained({16, 32}, blockSize));
1972 assert(llvm::is_contained({4u, 6u, 8u}, bitWidth));
1973
1974 const bool isFp8 = bitWidth == 8;
1975 const bool isBlock16 = blockSize == 16;
1976
1977 if (!isFp8) {
1978 int32_t bit0 = isBlock16;
1979 assert(llvm::is_contained({0, 1, 2}, firstScaleByte));
1980 int32_t bit1 = (firstScaleByte == 2) << 1;
1981 assert(llvm::is_contained({0, 1}, scaleWaveHalf));
1982 int32_t bit2 = scaleWaveHalf << 2;
1983 return bit2 | bit1 | bit0;
1984 }
1985
1986 int32_t bit0 = isBlock16;
1987 // firstScaleByte is guaranteed to be defined by two bits.
1988 assert(llvm::is_contained({0, 1, 2, 3}, firstScaleByte));
1989 int32_t bits2and1 = firstScaleByte << 1;
1990 assert(llvm::is_contained({0, 1}, scaleWaveHalf));
1991 int32_t bit3 = scaleWaveHalf << 3;
1992 int32_t bits = bit3 | bits2and1 | bit0;
1993 // These are invalid cases.
1994 assert(!llvm::is_contained(
1995 {0b0011, 0b0101, 0b0111, 0b1000, 0b1001, 0b1011, 0b1111}, bits));
1996 return bits;
1997}
1998
1999static std::optional<StringRef>
2000scaledExtPacked816ToIntrinsic(Type srcElemType, Type destElemType) {
2001 using fp4 = Float4E2M1FNType;
2002 using fp8 = Float8E4M3FNType;
2003 using bf8 = Float8E5M2Type;
2004 using fp6 = Float6E2M3FNType;
2005 using bf6 = Float6E3M2FNType;
2006 if (isa<fp4>(srcElemType)) {
2007 if (destElemType.isF16())
2008 return ROCDL::CvtPkScalePk8F16Fp4Op::getOperationName();
2009 if (destElemType.isBF16())
2010 return ROCDL::CvtPkScalePk8Bf16Fp4Op::getOperationName();
2011 if (destElemType.isF32())
2012 return ROCDL::CvtPkScalePk8F32Fp4Op::getOperationName();
2013 return std::nullopt;
2014 }
2015 if (isa<fp8>(srcElemType)) {
2016 if (destElemType.isF16())
2017 return ROCDL::CvtPkScalePk8F16Fp8Op::getOperationName();
2018 if (destElemType.isBF16())
2019 return ROCDL::CvtPkScalePk8Bf16Fp8Op::getOperationName();
2020 if (destElemType.isF32())
2021 return ROCDL::CvtPkScalePk8F32Fp8Op::getOperationName();
2022 return std::nullopt;
2023 }
2024 if (isa<bf8>(srcElemType)) {
2025 if (destElemType.isF16())
2026 return ROCDL::CvtPkScalePk8F16Bf8Op::getOperationName();
2027 if (destElemType.isBF16())
2028 return ROCDL::CvtPkScalePk8Bf16Bf8Op::getOperationName();
2029 if (destElemType.isF32())
2030 return ROCDL::CvtPkScalePk8F32Bf8Op::getOperationName();
2031 return std::nullopt;
2032 }
2033 if (isa<fp6>(srcElemType)) {
2034 if (destElemType.isF16())
2035 return ROCDL::CvtPkScalePk16F16Fp6Op::getOperationName();
2036 if (destElemType.isBF16())
2037 return ROCDL::CvtPkScalePk16Bf16Fp6Op::getOperationName();
2038 if (destElemType.isF32())
2039 return ROCDL::CvtPkScalePk16F32Fp6Op::getOperationName();
2040 return std::nullopt;
2041 }
2042 if (isa<bf6>(srcElemType)) {
2043 if (destElemType.isF16())
2044 return ROCDL::CvtPkScalePk16F16Bf6Op::getOperationName();
2045 if (destElemType.isBF16())
2046 return ROCDL::CvtPkScalePk16Bf16Bf6Op::getOperationName();
2047 if (destElemType.isF32())
2048 return ROCDL::CvtPkScalePk16F32Bf6Op::getOperationName();
2049 return std::nullopt;
2050 }
2051 llvm_unreachable("invalid combination of element types for packed conversion "
2052 "instructions");
2053}
2054
2055LogicalResult ScaledExtPackedMatrixOpLowering::matchAndRewrite(
2056 ScaledExtPackedMatrixOp op, ScaledExtPackedMatrixOpAdaptor adaptor,
2057 ConversionPatternRewriter &rewriter) const {
2058 using fp4 = Float4E2M1FNType;
2059 using fp8 = Float8E4M3FNType;
2060 using bf8 = Float8E5M2Type;
2061 using fp6 = Float6E2M3FNType;
2062 using bf6 = Float6E3M2FNType;
2063 Location loc = op.getLoc();
2064 if (chipset != kGfx1250) {
2065 return rewriter.notifyMatchFailure(
2066 loc,
2067 "Scaled fp packed conversion instructions are not available on target "
2068 "architecture and their emulation is not implemented");
2069 }
2070 // Convert user-facing firstScaleLane (0 or 16) to the half of the wave that
2071 // is being selected.
2072 int32_t scaleWaveHalf = op.getFirstScaleLane() / 16;
2073 int32_t firstScaleByte = op.getFirstScaleByte();
2074 int32_t blockSize = op.getBlockSize();
2075 auto sourceType = cast<VectorType>(op.getSource().getType());
2076 auto srcElemType = cast<FloatType>(sourceType.getElementType());
2077 unsigned bitWidth = srcElemType.getWidth();
2078
2079 auto targetType = cast<VectorType>(op.getResult().getType());
2080 auto destElemType = cast<FloatType>(targetType.getElementType());
2081
2082 IntegerType i32 = rewriter.getI32Type();
2083 Value source = adaptor.getSource();
2084 Type llvmResultType = typeConverter->convertType(op.getResult().getType());
2085 Type packedType = nullptr;
2086 if (isa<fp4>(srcElemType)) {
2087 packedType = i32;
2088 packedType = getTypeConverter()->convertType(packedType);
2089 } else if (isa<fp8, bf8>(srcElemType)) {
2090 packedType = VectorType::get(2, i32);
2091 packedType = getTypeConverter()->convertType(packedType);
2092 } else if (isa<fp6, bf6>(srcElemType)) {
2093 packedType = VectorType::get(3, i32);
2094 packedType = getTypeConverter()->convertType(packedType);
2095 } else {
2096 llvm_unreachable("invalid element type for packed scaled ext");
2097 }
2098
2099 if (!packedType || !llvmResultType) {
2100 return rewriter.notifyMatchFailure(op, "type conversion failed");
2101 }
2102
2103 std::optional<StringRef> maybeIntrinsic =
2104 scaledExtPacked816ToIntrinsic(srcElemType, destElemType);
2105 if (!maybeIntrinsic.has_value())
2106 return op.emitOpError(
2107 "no intrinsic matching packed scaled conversion on the given chipset");
2108
2109 int32_t scaleSel =
2110 getScaleSel(blockSize, bitWidth, scaleWaveHalf, firstScaleByte);
2111 Value castedScale =
2112 LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getScale());
2113 Value castedSource =
2114 LLVM::BitcastOp::create(rewriter, loc, packedType, source);
2115
2116 OperationState loweredOp(loc, *maybeIntrinsic);
2117 loweredOp.addTypes({llvmResultType});
2118 loweredOp.addOperands({castedSource, castedScale});
2119
2120 SmallVector<NamedAttribute, 1> attrs;
2121 attrs.push_back(
2122 NamedAttribute("scaleSel", rewriter.getI32IntegerAttr(scaleSel)));
2123
2124 loweredOp.addAttributes(attrs);
2125 Operation *lowered = rewriter.create(loweredOp);
2126 rewriter.replaceOp(op, lowered);
2127
2128 return success();
2129}
2130
2131LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
2132 ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
2133 ConversionPatternRewriter &rewriter) const {
2134 Location loc = op.getLoc();
2135 if (chipset != kGfx950)
2136 return rewriter.notifyMatchFailure(
2137 loc, "Scaled fp conversion instructions are not available on target "
2138 "architecture and their emulation is not implemented");
2139 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2140
2141 Value source = adaptor.getSource();
2142 Value scale = adaptor.getScale();
2143
2144 VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
2145 Type sourceElemType = sourceVecType.getElementType();
2146 VectorType destVecType = cast<VectorType>(op.getResult().getType());
2147 Type destElemType = destVecType.getElementType();
2148
2149 VectorType packedVecType;
2150 if (isa<Float8E5M2Type, Float8E4M3FNType>(sourceElemType)) {
2151 VectorType v4i8 = VectorType::get(4, rewriter.getI8Type());
2152 packedVecType = cast<VectorType>(getTypeConverter()->convertType(v4i8));
2153 } else if (isa<Float4E2M1FNType>(sourceElemType)) {
2154 VectorType v8i4 = VectorType::get(8, rewriter.getI4Type());
2155 packedVecType = cast<VectorType>(getTypeConverter()->convertType(v8i4));
2156 } else {
2157 llvm_unreachable("invalid element type for scaled ext");
2158 }
2159
2160 // Extend to a packedVectorType
2161 if (sourceVecType.getNumElements() < packedVecType.getNumElements()) {
2162 Value longVec = LLVM::ZeroOp::create(rewriter, loc, packedVecType);
2163 if (!sourceVecType) {
2164 longVec = LLVM::InsertElementOp::create(
2165 rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0));
2166 } else {
2167 for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
2168 Value idx = createI32Constant(rewriter, loc, i);
2169 Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
2170 longVec =
2171 LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
2172 }
2173 }
2174 source = longVec;
2175 }
2176 Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
2177
2178 if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF32())
2179 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Bf8Op>(
2180 op, destVecType, i32Source, scale, op.getIndex());
2181 else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF16())
2182 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Bf8Op>(
2183 op, destVecType, i32Source, scale, op.getIndex());
2184 else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isBF16())
2185 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Bf8Op>(
2186 op, destVecType, i32Source, scale, op.getIndex());
2187 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF32())
2188 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp8Op>(
2189 op, destVecType, i32Source, scale, op.getIndex());
2190 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF16())
2191 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp8Op>(
2192 op, destVecType, i32Source, scale, op.getIndex());
2193 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isBF16())
2194 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp8Op>(
2195 op, destVecType, i32Source, scale, op.getIndex());
2196 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF32())
2197 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp4Op>(
2198 op, destVecType, i32Source, scale, op.getIndex());
2199 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF16())
2200 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp4Op>(
2201 op, destVecType, i32Source, scale, op.getIndex());
2202 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isBF16())
2203 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp4Op>(
2204 op, destVecType, i32Source, scale, op.getIndex());
2205 else
2206 return failure();
2207
2208 return success();
2209}
2210
2211LogicalResult PackedScaledTruncOpLowering::matchAndRewrite(
2212 PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
2213 ConversionPatternRewriter &rewriter) const {
2214 Location loc = op.getLoc();
2215 if (chipset != kGfx950)
2216 return rewriter.notifyMatchFailure(
2217 loc, "Scaled fp conversion instructions are not available on target "
2218 "architecture and their emulation is not implemented");
2219 Type v2i16 = getTypeConverter()->convertType(
2220 VectorType::get(2, rewriter.getI16Type()));
2221 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2222
2223 Type resultType = op.getResult().getType();
2224 Type resultElemType = getElementTypeOrSelf(resultType);
2225 VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
2226 Type sourceElemType = sourceVecType.getElementType();
2227
2228 Type intResultType = isa<Float4E2M1FNType>(resultElemType) ? i32 : v2i16;
2229
2230 Value source = adaptor.getSource();
2231 Value scale = adaptor.getScale();
2232 Value existing = adaptor.getExisting();
2233 if (existing)
2234 existing = LLVM::BitcastOp::create(rewriter, loc, intResultType, existing);
2235 else
2236 existing = LLVM::ZeroOp::create(rewriter, loc, intResultType);
2237
2238 if (sourceVecType.getNumElements() < 2) {
2239 Value c0 = createI32Constant(rewriter, loc, 0);
2240 Value elem0 = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
2241 VectorType v2 = VectorType::get(2, sourceElemType);
2242 source = LLVM::ZeroOp::create(rewriter, loc, v2);
2243 source = LLVM::InsertElementOp::create(rewriter, loc, source, elem0, c0);
2244 }
2245
2246 Value sourceA, sourceB;
2247 if (sourceElemType.isF32()) {
2248 Value c0 = createI32Constant(rewriter, loc, 0);
2249 Value c1 = createI32Constant(rewriter, loc, 1);
2250 sourceA = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
2251 sourceB = LLVM::ExtractElementOp::create(rewriter, loc, source, c1);
2252 }
2253
2254 Value result;
2255 if (sourceElemType.isF32() && isa<Float8E5M2Type>(resultElemType))
2256 result = ROCDL::CvtScaleF32PkBf8F32Op::create(rewriter, loc, intResultType,
2257 existing, sourceA, sourceB,
2258 scale, op.getIndex());
2259 else if (sourceElemType.isF16() && isa<Float8E5M2Type>(resultElemType))
2260 result = ROCDL::CvtScaleF32PkBf8F16Op::create(
2261 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2262 else if (sourceElemType.isBF16() && isa<Float8E5M2Type>(resultElemType))
2263 result = ROCDL::CvtScaleF32PkBf8Bf16Op::create(
2264 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2265 else if (sourceElemType.isF32() && isa<Float8E4M3FNType>(resultElemType))
2266 result = ROCDL::CvtScaleF32PkFp8F32Op::create(rewriter, loc, intResultType,
2267 existing, sourceA, sourceB,
2268 scale, op.getIndex());
2269 else if (sourceElemType.isF16() && isa<Float8E4M3FNType>(resultElemType))
2270 result = ROCDL::CvtScaleF32PkFp8F16Op::create(
2271 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2272 else if (sourceElemType.isBF16() && isa<Float8E4M3FNType>(resultElemType))
2273 result = ROCDL::CvtScaleF32PkFp8Bf16Op::create(
2274 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2275 else if (sourceElemType.isF32() && isa<Float4E2M1FNType>(resultElemType))
2276 result = ROCDL::CvtScaleF32PkFp4F32Op::create(rewriter, loc, intResultType,
2277 existing, sourceA, sourceB,
2278 scale, op.getIndex());
2279 else if (sourceElemType.isF16() && isa<Float4E2M1FNType>(resultElemType))
2280 result = ROCDL::CvtScaleF32PkFp4F16Op::create(
2281 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2282 else if (sourceElemType.isBF16() && isa<Float4E2M1FNType>(resultElemType))
2283 result = ROCDL::CvtScaleF32PkFp4Bf16Op::create(
2284 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2285 else
2286 return failure();
2287
2288 result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
2289 op, getTypeConverter()->convertType(resultType), result);
2290 return success();
2291}
2292
2293LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
2294 PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
2295 ConversionPatternRewriter &rewriter) const {
2296 Location loc = op.getLoc();
2297 if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
2298 return rewriter.notifyMatchFailure(
2299 loc, "Fp8 conversion instructions are not available on target "
2300 "architecture and their emulation is not implemented");
2301 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2302
2303 Type resultType = op.getResult().getType();
2304 Type resultElemType = getElementTypeOrSelf(resultType);
2305
2306 Value sourceA = adaptor.getSourceA();
2307 Value sourceB = adaptor.getSourceB();
2308 if (!sourceB)
2309 sourceB = LLVM::UndefOp::create(rewriter, loc, sourceA.getType());
2310 Value existing = adaptor.getExisting();
2311 if (existing)
2312 existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
2313 else
2314 existing = LLVM::UndefOp::create(rewriter, loc, i32);
2315
2316 Value result;
2317 if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
2318 result = ROCDL::CvtPkBf8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
2319 existing, op.getWordIndex());
2320 else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
2321 result = ROCDL::CvtPkFp8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
2322 existing, op.getWordIndex());
2323
2324 result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
2325 op, getTypeConverter()->convertType(resultType), result);
2326 return success();
2327}
2328
2329LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
2330 PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
2331 ConversionPatternRewriter &rewriter) const {
2332 Location loc = op.getLoc();
2333 if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
2334 return rewriter.notifyMatchFailure(
2335 loc, "Fp8 conversion instructions are not available on target "
2336 "architecture and their emulation is not implemented");
2337 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2338
2339 Type resultType = op.getResult().getType();
2340 Type resultElemType = getElementTypeOrSelf(resultType);
2341
2342 Value source = adaptor.getSource();
2343 Value stoch = adaptor.getStochiasticParam();
2344 Value existing = adaptor.getExisting();
2345 if (existing)
2346 existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
2347 else
2348 existing = LLVM::UndefOp::create(rewriter, loc, i32);
2349
2350 Value result;
2351 if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
2352 result = ROCDL::CvtSrBf8F32Op::create(rewriter, loc, i32, source, stoch,
2353 existing, op.getStoreIndex());
2354 else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
2355 result = ROCDL::CvtSrFp8F32Op::create(rewriter, loc, i32, source, stoch,
2356 existing, op.getStoreIndex());
2357
2358 result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
2359 op, getTypeConverter()->convertType(resultType), result);
2360 return success();
2361}
2362
2363 // Implements the AMDGPUDPPLowering pattern, which converts the amdgpu.dpp
2364 // operation into the corresponding ROCDL instructions.
2365struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
2366 AMDGPUDPPLowering(const LLVMTypeConverter &converter, Chipset chipset)
2367 : ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
2368 Chipset chipset;
2369
2370 LogicalResult
2371 matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
2372 ConversionPatternRewriter &rewriter) const override {
2373
2374 // Convert the source operand to the corresponding LLVM type
2375 Location loc = DppOp.getLoc();
2376 Value src = adaptor.getSrc();
2377 Value old = adaptor.getOld();
2378 Type srcType = src.getType();
2379 Type oldType = old.getType();
2380 Type llvmType = nullptr;
2381 if (srcType.getIntOrFloatBitWidth() < 32) {
2382 llvmType = rewriter.getI32Type();
2383 } else if (isa<FloatType>(srcType)) {
2384 llvmType = (srcType.getIntOrFloatBitWidth() == 32)
2385 ? rewriter.getF32Type()
2386 : rewriter.getF64Type();
2387 } else if (isa<IntegerType>(srcType)) {
2388 llvmType = (srcType.getIntOrFloatBitWidth() == 32)
2389 ? rewriter.getI32Type()
2390 : rewriter.getI64Type();
2391 }
2392 auto llvmSrcIntType = typeConverter->convertType(
2393 rewriter.getIntegerType(srcType.getIntOrFloatBitWidth()));
2394
2395 // If the source type is narrower than 32 bits, bitcast it to i32.
2396 auto convertOperand = [&](Value operand, Type operandType) {
2397 if (operandType.getIntOrFloatBitWidth() <= 16) {
2398 if (llvm::isa<FloatType>(operandType)) {
2399 operand =
2400 LLVM::BitcastOp::create(rewriter, loc, llvmSrcIntType, operand);
2401 }
2402 auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
2403 32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
2404 Value undefVec = LLVM::UndefOp::create(rewriter, loc, llvmVecType);
2405 operand =
2406 LLVM::InsertElementOp::create(rewriter, loc, undefVec, operand,
2407 createI32Constant(rewriter, loc, 0));
2408 operand = LLVM::BitcastOp::create(rewriter, loc, llvmType, operand);
2409 }
2410 return operand;
2411 };
2412
2413 src = convertOperand(src, srcType);
2414 old = convertOperand(old, oldType);
2415
2416 // These values are taken from llvm/lib/Target/AMDGPU/SIDefines.h.
2417 enum DppCtrl : unsigned {
2418 ROW_SHL0 = 0x100,
2419 ROW_SHR0 = 0x110,
2420 ROW_ROR0 = 0x120,
2421 WAVE_SHL1 = 0x130,
2422 WAVE_ROL1 = 0x134,
2423 WAVE_SHR1 = 0x138,
2424 WAVE_ROR1 = 0x13C,
2425 ROW_MIRROR = 0x140,
2426 ROW_HALF_MIRROR = 0x141,
2427 BCAST15 = 0x142,
2428 BCAST31 = 0x143,
2429 };
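    // Editorial example (added for clarity): the switch below combines these
    // base values with the permutation argument, e.g. row_shl with a shift of
    // 3 encodes as ROW_SHL0 + 3 = 0x103, and quad_perm [0, 1, 2, 3] packs the
    // 2-bit selectors into DppCtrl = 0xE4 (the identity permutation).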
2430
2431 auto kind = DppOp.getKind();
2432 auto permArgument = DppOp.getPermArgument();
2433 uint32_t DppCtrl = 0;
2434
2435 switch (kind) {
2436
2437 case DPPPerm::quad_perm: {
2438 auto quadPermAttr = cast<ArrayAttr>(*permArgument);
2439 int32_t i = 0;
2440 for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
2441 uint32_t num = elem.getInt();
2442 DppCtrl |= num << (i * 2);
2443 i++;
2444 }
2445 break;
2446 }
2447 case DPPPerm::row_shl: {
2448 auto intAttr = cast<IntegerAttr>(*permArgument);
2449 DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
2450 break;
2451 }
2452 case DPPPerm::row_shr: {
2453 auto intAttr = cast<IntegerAttr>(*permArgument);
2454 DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
2455 break;
2456 }
2457 case DPPPerm::row_ror: {
2458 auto intAttr = cast<IntegerAttr>(*permArgument);
2459 DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
2460 break;
2461 }
2462 case DPPPerm::wave_shl:
2463 DppCtrl = DppCtrl::WAVE_SHL1;
2464 break;
2465 case DPPPerm::wave_shr:
2466 DppCtrl = DppCtrl::WAVE_SHR1;
2467 break;
2468 case DPPPerm::wave_rol:
2469 DppCtrl = DppCtrl::WAVE_ROL1;
2470 break;
2471 case DPPPerm::wave_ror:
2472 DppCtrl = DppCtrl::WAVE_ROR1;
2473 break;
2474 case DPPPerm::row_mirror:
2475 DppCtrl = DppCtrl::ROW_MIRROR;
2476 break;
2477 case DPPPerm::row_half_mirror:
2478 DppCtrl = DppCtrl::ROW_HALF_MIRROR;
2479 break;
2480 case DPPPerm::row_bcast_15:
2481 DppCtrl = DppCtrl::BCAST15;
2482 break;
2483 case DPPPerm::row_bcast_31:
2484 DppCtrl = DppCtrl::BCAST31;
2485 break;
2486 }
2487
2488 // Read the row_mask, bank_mask, and bound_ctrl attributes and create the
2489 // corresponding constants.
2490 auto rowMask = DppOp->getAttrOfType<IntegerAttr>("row_mask").getInt();
2491 auto bankMask = DppOp->getAttrOfType<IntegerAttr>("bank_mask").getInt();
2492 bool boundCtrl = DppOp->getAttrOfType<BoolAttr>("bound_ctrl").getValue();
2493
2494 // Create a ROCDL DPPUpdateOp instruction with the appropriate attributes.
2495 auto dppMovOp =
2496 ROCDL::DPPUpdateOp::create(rewriter, loc, llvmType, old, src, DppCtrl,
2497 rowMask, bankMask, boundCtrl);
2498
2499 Value result = dppMovOp.getRes();
2500 if (srcType.getIntOrFloatBitWidth() < 32) {
2501 result = LLVM::TruncOp::create(rewriter, loc, llvmSrcIntType, result);
2502 if (!llvm::isa<IntegerType>(srcType)) {
2503 result = LLVM::BitcastOp::create(rewriter, loc, srcType, result);
2504 }
2505 }
2506
2507 // Replace the amdgpu.dpp operation with the newly created
2508 // ROCDL DPPUpdateOp instruction.
2509 rewriter.replaceOp(DppOp, ValueRange(result));
2510 return success();
2511 }
2512};
2513
2514struct AMDGPUSwizzleBitModeLowering
2515 : public ConvertOpToLLVMPattern<SwizzleBitModeOp> {
2516 using ConvertOpToLLVMPattern<SwizzleBitModeOp>::ConvertOpToLLVMPattern;
2517
2518 LogicalResult
2519 matchAndRewrite(SwizzleBitModeOp op, OpAdaptor adaptor,
2520 ConversionPatternRewriter &rewriter) const override {
2521 Location loc = op.getLoc();
2522 Type i32 = rewriter.getI32Type();
2523 Value src = adaptor.getSrc();
2524 SmallVector<Value> decomposed =
2525 LLVM::decomposeValue(rewriter, loc, src, i32);
2526 unsigned andMask = op.getAndMask();
2527 unsigned orMask = op.getOrMask();
2528 unsigned xorMask = op.getXorMask();
2529
2530 // bit 15 is 0 for the BitMode swizzle.
2531 // https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/
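    // Editorial example (added for clarity): with andMask = 0x1F, orMask = 0,
    // and xorMask = 1, the packed value is 0x1F | (1 << 10) = 0x41F, which
    // swaps adjacent lanes.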
2532 unsigned mask = andMask | (orMask << 5) | (xorMask << 10);
2533 Value maskValue = createI32Constant(rewriter, loc, mask);
2534 SmallVector<Value> swizzled;
2535 for (Value v : decomposed) {
2536 Value res =
2537 ROCDL::DsSwizzleOp::create(rewriter, loc, v.getType(), v, maskValue);
2538 swizzled.emplace_back(res);
2539 }
2540
2541 Value result = LLVM::composeValue(rewriter, loc, swizzled, src.getType());
2542 rewriter.replaceOp(op, result);
2543 return success();
2544 }
2545};
2546
2547struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneSwapOp> {
2548 using ConvertOpToLLVMPattern<PermlaneSwapOp>::ConvertOpToLLVMPattern;
2549
2550 AMDGPUPermlaneLowering(const LLVMTypeConverter &converter, Chipset chipset)
2551 : ConvertOpToLLVMPattern<PermlaneSwapOp>(converter), chipset(chipset) {}
2552 Chipset chipset;
2553
2554 LogicalResult
2555 matchAndRewrite(PermlaneSwapOp op, OpAdaptor adaptor,
2556 ConversionPatternRewriter &rewriter) const override {
2557 if (chipset < kGfx950)
2558 return op->emitOpError("permlane_swap is only supported on gfx950+");
2559
2560 Location loc = op.getLoc();
2561 Type i32 = rewriter.getI32Type();
2562 Value src = adaptor.getSrc();
2563 unsigned rowLength = op.getRowLength();
2564 bool fi = op.getFetchInactive();
2565 bool boundctrl = op.getBoundCtrl();
2566
2567 SmallVector<Value> decomposed =
2568 LLVM::decomposeValue(rewriter, loc, src, i32);
2569
2570 SmallVector<Value> permuted;
2571 for (Value v : decomposed) {
2572 Value res;
2573 Type i32pair = LLVM::LLVMStructType::getLiteral(
2574 rewriter.getContext(), {v.getType(), v.getType()});
2575
2576 if (rowLength == 16)
2577 res = ROCDL::Permlane16SwapOp::create(rewriter, loc, i32pair, v, v, fi,
2578 boundctrl);
2579 else if (rowLength == 32)
2580 res = ROCDL::Permlane32SwapOp::create(rewriter, loc, i32pair, v, v, fi,
2581 boundctrl);
2582 else
2583 llvm_unreachable("unsupported row length");
2584
2585 Value vdst0 = LLVM::ExtractValueOp::create(rewriter, loc, res, {0});
2586 Value vdst1 = LLVM::ExtractValueOp::create(rewriter, loc, res, {1});
2587
2588 Value isEqual = LLVM::ICmpOp::create(rewriter, loc,
2589 LLVM::ICmpPredicate::eq, vdst0, v);
2590
2591 // Per `permlane(16|32)` semantics: if the first extracted element equals
2592 // 'v', the result is the second element; otherwise it is the first.
2593 Value vdstNew =
2594 LLVM::SelectOp::create(rewriter, loc, isEqual, vdst1, vdst0);
2595 permuted.emplace_back(vdstNew);
2596 }
2597
2598 Value result = LLVM::composeValue(rewriter, loc, permuted, src.getType());
2599 rewriter.replaceOp(op, result);
2600 return success();
2601 }
2602};
2603
2604static Value setValueAtOffset(ConversionPatternRewriter &rewriter, Location loc,
2605 Value accumulator, Value value, int64_t shift) {
2606 shift = shift % 32;
2607 Value shiftAmount;
2608 if (shift != 0) {
2609 shiftAmount = createI32Constant(rewriter, loc, shift % 32);
2610 value = LLVM::ShlOp::create(rewriter, loc, value, shiftAmount);
2611 }
2612
2613 if (matchPattern(accumulator, mlir::m_Zero()))
2614 return value;
2615
2616 constexpr bool isDisjoint = true;
2617 return LLVM::OrOp::create(rewriter, loc, accumulator, value, isDisjoint);
2618}
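// Editorial note (added for clarity): this helper ORs `value` into a 32-bit
// word at bit offset `shift` (taken modulo 32), e.g. a call with shift == 16
// contributes `value << 16`; when the accumulator is a known zero the OR is
// elided and `value` is returned directly.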
2619
2620template <typename BaseOp>
2621struct AMDGPUMakeDmaBaseLowering : public ConvertOpToLLVMPattern<BaseOp> {
2622 using ConvertOpToLLVMPattern<BaseOp>::ConvertOpToLLVMPattern;
2623 using Adaptor = typename ConvertOpToLLVMPattern<BaseOp>::OpAdaptor;
2624
2625 AMDGPUMakeDmaBaseLowering(const LLVMTypeConverter &converter, Chipset chipset)
2626 : ConvertOpToLLVMPattern<BaseOp>(converter), chipset(chipset) {}
2627 Chipset chipset;
2628
2629 LogicalResult
2630 matchAndRewrite(BaseOp op, Adaptor adaptor,
2631 ConversionPatternRewriter &rewriter) const override {
2632 if (chipset < kGfx1250)
2633 return op->emitOpError("make_dma_base is only supported on gfx1250+");
2634
2635 Location loc = op.getLoc();
2636
2637 constexpr int32_t constlen = 4;
2638 Value consts[constlen];
2639 for (int64_t i = 0; i < constlen; ++i)
2640 consts[i] = createI32Constant(rewriter, loc, i);
2641
2642 constexpr int32_t sgprslen = constlen;
2643 Value sgprs[sgprslen];
2644 for (int64_t i = 0; i < sgprslen; ++i) {
2645 sgprs[i] = consts[0];
2646 }
2647
2648 sgprs[0] = consts[1];
2649
2650 if constexpr (BaseOp::isGather()) {
2651 sgprs[0] = setValueAtOffset(rewriter, loc, sgprs[0], consts[1], 30);
2652
2653 auto type = cast<TDMGatherBaseType>(op.getResult().getType());
2654 Type indexType = type.getIndexType();
2655 unsigned indexSize = indexType.getIntOrFloatBitWidth();
2656 assert(llvm::is_contained({16u, 32u}, indexSize) &&
2657 "expected index_size to be 16 or 32");
2658 unsigned idx = (indexSize / 16) - 1;
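      // Editorial note (added for clarity): idx is 0 for 16-bit gather
      // indices and 1 for 32-bit ones, so the bit at offset 31 is only set
      // for 32-bit indices.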
2659
2660 if (idx)
2661 sgprs[0] = setValueAtOffset(rewriter, loc, sgprs[0], consts[1], 31);
2662 }
2663
2664 ValueRange ldsIndices = adaptor.getLdsIndices();
2665 Value lds = adaptor.getLds();
2666 auto ldsMemRefType = cast<MemRefType>(op.getLds().getType());
2667
2668 Value ldsPtr = ConvertToLLVMPattern::getStridedElementPtr(
2669 rewriter, loc, ldsMemRefType, lds, ldsIndices);
2670
2671 ValueRange globalIndices = adaptor.getGlobalIndices();
2672 Value global = adaptor.getGlobal();
2673 auto globalMemRefType = cast<MemRefType>(op.getGlobal().getType());
2674
2675 Value globalPtr = ConvertToLLVMPattern::getStridedElementPtr(
2676 rewriter, loc, globalMemRefType, global, globalIndices);
2677
2678 Type i32 = rewriter.getI32Type();
2679 Type i64 = rewriter.getI64Type();
2680
2681 sgprs[1] = LLVM::PtrToIntOp::create(rewriter, loc, i32, ldsPtr);
2682 Value castForGlobalAddr =
2683 LLVM::PtrToIntOp::create(rewriter, loc, i64, globalPtr);
2684
2685 sgprs[2] = LLVM::TruncOp::create(rewriter, loc, i32, castForGlobalAddr);
2686
2687 Value shift = LLVM::LShrOp::create(rewriter, loc, castForGlobalAddr,
2688 createI64Constant(rewriter, loc, 32));
2689
2690 Value highHalf = LLVM::TruncOp::create(rewriter, loc, i32, shift);
2691
2692 Value mask = createI32Constant(rewriter, loc, (1ull << 25) - 1);
2693 highHalf = LLVM::AndOp::create(rewriter, loc, highHalf, mask);
2694
2695 sgprs[3] = setValueAtOffset(rewriter, loc, highHalf, consts[2], 30);
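    // Editorial note (added for clarity): only the low 25 bits of the high
    // half of the 64-bit global address are kept, and setValueAtOffset places
    // consts[2] (= 2) at bits [31:30] of sgprs[3] above.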
2696
2697 Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
2698 assert(v4i32 && "expected type conversion to succeed");
2699 Value result = LLVM::PoisonOp::create(rewriter, loc, v4i32);
2700
2701 for (auto [sgpr, constant] : llvm::zip_equal(sgprs, consts))
2702 result =
2703 LLVM::InsertElementOp::create(rewriter, loc, result, sgpr, constant);
2704
2705 rewriter.replaceOp(op, result);
2706 return success();
2707 }
2708};
2709
2710template <typename DescriptorOp>
2711struct AMDGPULowerDescriptor : public ConvertOpToLLVMPattern<DescriptorOp> {
2712 using ConvertOpToLLVMPattern<DescriptorOp>::ConvertOpToLLVMPattern;
2713 using OpAdaptor = typename ConvertOpToLLVMPattern<DescriptorOp>::OpAdaptor;
2714
2715 AMDGPULowerDescriptor(const LLVMTypeConverter &converter, Chipset chipset)
2716 : ConvertOpToLLVMPattern<DescriptorOp>(converter), chipset(chipset) {}
2717 Chipset chipset;
2718
2719 Value getDGroup0(OpAdaptor adaptor) const { return adaptor.getBase(); }
2720
2721 Value setWorkgroupMask(DescriptorOp op, OpAdaptor adaptor,
2722 ConversionPatternRewriter &rewriter, Location loc,
2723 Value sgpr0) const {
2724 Value mask = op.getWorkgroupMask();
2725 if (!mask)
2726 return sgpr0;
2727
2728 Type i16 = rewriter.getI16Type();
2729 mask = LLVM::BitcastOp::create(rewriter, loc, i16, mask);
2730 Type i32 = rewriter.getI32Type();
2731 Value extendedMask = LLVM::ZExtOp::create(rewriter, loc, i32, mask);
2732 return setValueAtOffset(rewriter, loc, sgpr0, extendedMask, 0);
2733 }
2734
2735 Value setDataSize(DescriptorOp op, OpAdaptor adaptor,
2736 ConversionPatternRewriter &rewriter, Location loc,
2737 Value sgpr0, ArrayRef<Value> consts) const {
2738 unsigned elementTypeWidthInBits = op.getElementTypeWidth();
2739 assert(llvm::is_contained({8u, 16u, 32u, 64u}, elementTypeWidthInBits) &&
2740 "expected type width to be 8, 16, 32, or 64.");
2741 int64_t idx = llvm::Log2_32(elementTypeWidthInBits / 8);
2742 Value size = consts[idx];
2743 return setValueAtOffset(rewriter, loc, sgpr0, size, 16);
2744 }
2745
2746 Value setAtomicBarrier(DescriptorOp op, OpAdaptor adaptor,
2747 ConversionPatternRewriter &rewriter, Location loc,
2748 Value sgpr0, ArrayRef<Value> consts) const {
2749 if (!adaptor.getAtomicBarrierAddress())
2750 return sgpr0;
2751
2752 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 18);
2753 }
2754
2755 Value setIterateEnable(DescriptorOp op, OpAdaptor adaptor,
2756 ConversionPatternRewriter &rewriter, Location loc,
2757 Value sgpr0, ArrayRef<Value> consts) const {
2758 if (!adaptor.getGlobalIncrement())
2759 return sgpr0;
2760
2761 // Value is ignored when in gather mode.
2762 // TODO: emit error earlier?
2763 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 19);
2764 }
2765
2766 Value setPadEnable(DescriptorOp op, OpAdaptor adaptor,
2767 ConversionPatternRewriter &rewriter, Location loc,
2768 Value sgpr0, ArrayRef<Value> consts) const {
2769 if (!op.getPadAmount())
2770 return sgpr0;
2771
2772 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 20);
2773 }
2774
2775 Value setEarlyTimeout(DescriptorOp op, OpAdaptor adaptor,
2776 ConversionPatternRewriter &rewriter, Location loc,
2777 Value sgpr0, ArrayRef<Value> consts) const {
2778 if (!op.getWorkgroupMask())
2779 return sgpr0;
2780
2781 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 21);
2782 }
2783
2784 Value setPadInterval(DescriptorOp op, OpAdaptor adaptor,
2785 ConversionPatternRewriter &rewriter, Location loc,
2786 Value sgpr0, ArrayRef<Value> consts) const {
2787 if (!op.getPadAmount())
2788 return sgpr0;
2789
2790 // pre-condition: padInterval must be a power of two between 2 and 256.
2791 // TODO: validate that the value satisfies this pre-condition.
2792 // If the pre-condition is violated, the higher bits of the
2793 // descriptor word may be corrupted. In a follow-up PR, implement
2794 // RuntimeVerifiableOpInterface to instrument conditions that need to be
2795 // checked at runtime.
2796 IntegerType i32 = rewriter.getI32Type();
2797 Value padInterval = adaptor.getPadInterval();
2798 padInterval = LLVM::CountTrailingZerosOp::create(rewriter, loc, i32,
2799 padInterval, false);
2800 padInterval = LLVM::SubOp::create(rewriter, loc, padInterval, consts[1]);
2801 // post-condition: padInterval is a value between 0 and 7.
2802 return setValueAtOffset(rewriter, loc, sgpr0, padInterval, 22);
2803 }
2804
2805 Value setPadAmount(DescriptorOp op, OpAdaptor adaptor,
2806 ConversionPatternRewriter &rewriter, Location loc,
2807 Value sgpr0, ArrayRef<Value> consts) const {
2808 if (!op.getPadAmount())
2809 return sgpr0;
2810
2811 // pre-condition: padAmount is a value in the inclusive range [1, 128].
2812 // TODO: validate that the value satisfies this pre-condition.
2813 // If the pre-condition is violated, the higher bits of the
2814 // descriptor word may be corrupted. In a follow-up PR, implement
2815 // RuntimeVerifiableOpInterface to instrument conditions that need to be
2816 // checked at runtime.
2817 Value padAmount = adaptor.getPadAmount();
2818 padAmount = LLVM::SubOp::create(rewriter, loc, padAmount, consts[1]);
2819 // post-condition: padAmount is a value in the inclusive range [0, 127].
2820 return setValueAtOffset(rewriter, loc, sgpr0, padAmount, 25);
2821 }
2822
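// Worked example for the two pad helpers above: a padInterval of 16 has
// cttz(16) == 4, so 4 - 1 == 3 is stored at offset 22, and a padAmount of
// 128 is stored as 128 - 1 == 127 at offset 25. Both fields are emitted only
// when the op carries a pad amount; the pad-enable bit itself is set at
// offset 20 by setPadEnable.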
2823 Value setAtomicBarrierAddress(DescriptorOp op, OpAdaptor adaptor,
2824 ConversionPatternRewriter &rewriter,
2825 Location loc, Value sgpr1,
2826 ArrayRef<Value> consts) const {
2827 if (!adaptor.getAtomicBarrierAddress())
2828 return sgpr1;
2829
2830 Value atomicBarrierAddress = adaptor.getAtomicBarrierAddress();
2831 auto barrierAddressTy =
2832 cast<MemRefType>(op.getAtomicBarrierAddress().getType());
2833 ValueRange atomicBarrierIndices = adaptor.getAtomicBarrierIndices();
2834 atomicBarrierAddress = ConvertToLLVMPattern::getStridedElementPtr(
2835 rewriter, loc, barrierAddressTy, atomicBarrierAddress,
2836 atomicBarrierIndices);
2837 IntegerType i32 = rewriter.getI32Type();
2838 // pre-condition: atomicBarrierAddress is aligned to 8 bytes, which implies
2839 // that its 3 LSBs are zero.
2840 // TODO: validate that the value satisfies this pre-condition.
2841 // In a follow-up PR, implement RuntimeVerifiableOpInterface
2842 // to instrument conditions that need to be checked at runtime.
2843 atomicBarrierAddress =
2844 LLVM::PtrToIntOp::create(rewriter, loc, i32, atomicBarrierAddress);
2845 atomicBarrierAddress =
2846 LLVM::LShrOp::create(rewriter, loc, atomicBarrierAddress, consts[3]);
2847 Value mask = createI32Constant(rewriter, loc, 0xFFFF);
2848 atomicBarrierAddress =
2849 LLVM::AndOp::create(rewriter, loc, atomicBarrierAddress, mask);
2850 return setValueAtOffset(rewriter, loc, sgpr1, atomicBarrierAddress, 32);
2851 }
2852
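// Worked example for setAtomicBarrierAddress above: the barrier pointer is
// converted to i32, shifted right by 3 (consts[3], valid because of the
// 8-byte alignment pre-condition), and masked to 16 bits, so an address of
// 0x2468 is stored as 0x2468 >> 3 == 0x48D in the low half of the second
// descriptor word (offset 32).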
2853 std::pair<Value, Value> setTensorDimX(DescriptorOp op, OpAdaptor adaptor,
2854 ConversionPatternRewriter &rewriter,
2855 Location loc, Value sgpr1, Value sgpr2,
2856 ArrayRef<Value> consts, uint64_t dimX,
2857 uint32_t offset) const {
2858 ArrayRef<int64_t> globalStaticSizes = adaptor.getGlobalStaticSizes();
2859 ValueRange globalDynamicSizes = adaptor.getGlobalDynamicSizes();
2860 SmallVector<OpFoldResult> mixedGlobalSizes =
2861 getMixedValues(globalStaticSizes, globalDynamicSizes, rewriter);
2862 if (mixedGlobalSizes.size() <= dimX)
2863 return {sgpr1, sgpr2};
2864
2865 OpFoldResult tensorDimXOpFoldResult = *(mixedGlobalSizes.rbegin() + dimX);
2866 // pre-condition: tensorDimX is less than 2^32 - 1.
2867 // TODO: validate that the value satisfies this pre-condition.
2868 // In a follow-up PR, implement RuntimeVerifiableOpInterface to instrument
2869 // conditions that need to be checked at runtime. This could also be fixed
2870 // by declaring mixedGlobalSizes to be a DynamicI32List.
2871 Value tensorDimX;
2872 if (auto attr = dyn_cast<Attribute>(tensorDimXOpFoldResult)) {
2873 tensorDimX =
2874 createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
2875 } else {
2876 IntegerType i32 = rewriter.getI32Type();
2877 tensorDimX = cast<Value>(tensorDimXOpFoldResult);
2878 tensorDimX = LLVM::TruncOp::create(rewriter, loc, i32, tensorDimX);
2879 }
2880
2881 sgpr1 = setValueAtOffset(rewriter, loc, sgpr1, tensorDimX, offset);
2882
2883 Value c16 = createI32Constant(rewriter, loc, 16);
2884 Value tensorDimXHigh = LLVM::LShrOp::create(rewriter, loc, tensorDimX, c16);
2885 sgpr2 = setValueAtOffset(rewriter, loc, sgpr2, tensorDimXHigh, offset + 16);
2886 return {sgpr1, sgpr2};
2887 }
2888
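// Sketch of the split performed by setTensorDimX above, assuming
// setValueAtOffset ORs `value << (offset % 32)` into a 32-bit word and drops
// the bits shifted past bit 31: for tensor dim 0 (offset 48) a size of
// 0x12345 leaves 0x2345 in the upper half of word 1, while
// 0x12345 >> 16 == 0x1 lands in the lower half of word 2, i.e. a 32-bit
// dimension is stored as two 16-bit halves straddling a word boundary.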
2889 std::pair<Value, Value> setTensorDim0(DescriptorOp op, OpAdaptor adaptor,
2890 ConversionPatternRewriter &rewriter,
2891 Location loc, Value sgpr1, Value sgpr2,
2892 ArrayRef<Value> consts) const {
2893 return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, sgpr2, consts, 0,
2894 48);
2895 }
2896
2897 std::pair<Value, Value> setTensorDim1(DescriptorOp op, OpAdaptor adaptor,
2898 ConversionPatternRewriter &rewriter,
2899 Location loc, Value sgpr2, Value sgpr3,
2900 ArrayRef<Value> consts) const {
2901 return setTensorDimX(op, adaptor, rewriter, loc, sgpr2, sgpr3, consts, 1,
2902 80);
2903 }
2904
2905 Value setTileDimX(DescriptorOp op, OpAdaptor adaptor,
2906 ConversionPatternRewriter &rewriter, Location loc,
2907 Value sgpr, ArrayRef<Value> consts, size_t dimX,
2908 int64_t offset) const {
2909 ArrayRef<int64_t> sharedStaticSizes = adaptor.getSharedStaticSizes();
2910 ValueRange sharedDynamicSizes = adaptor.getSharedDynamicSizes();
2911 SmallVector<OpFoldResult> mixedSharedSizes =
2912 getMixedValues(sharedStaticSizes, sharedDynamicSizes, rewriter);
2913 if (mixedSharedSizes.size() <= dimX)
2914 return sgpr;
2915
2916 OpFoldResult tileDimXOpFoldResult = *(mixedSharedSizes.rbegin() + dimX);
2917 // pre-condition: tileDimX is less than 2^16 - 1.
2918 // TODO: validate that the value satisfies this pre-condition.
2919 // If the pre-condition is violated, the higher bits of the
2920 // descriptor word may be corrupted. In a follow-up PR, implement
2921 // RuntimeVerifiableOpInterface to instrument conditions that need to be
2922 // checked at runtime. This could also be fixed by declaring
2923 // mixedSharedSizes to be a DynamicI16List.
2924 Value tileDimX;
2925 if (auto attr = dyn_cast<Attribute>(tileDimXOpFoldResult)) {
2926 tileDimX =
2927 createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
2928 } else {
2929 IntegerType i32 = rewriter.getI32Type();
2930 tileDimX = cast<Value>(tileDimXOpFoldResult);
2931 tileDimX = LLVM::TruncOp::create(rewriter, loc, i32, tileDimX);
2932 }
2933
2934 return setValueAtOffset(rewriter, loc, sgpr, tileDimX, offset);
2935 }
2936
2937 Value setTileDim0(DescriptorOp op, OpAdaptor adaptor,
2938 ConversionPatternRewriter &rewriter, Location loc,
2939 Value sgpr3, ArrayRef<Value> consts) const {
2940 return setTileDimX(op, adaptor, rewriter, loc, sgpr3, consts, 0, 112);
2941 }
2942
2943 Value setTileDim1(DescriptorOp op, OpAdaptor adaptor,
2944 ConversionPatternRewriter &rewriter, Location loc,
2945 Value sgpr4, ArrayRef<Value> consts) const {
2946 return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 1, 128);
2947 }
2948
2949 Value setValidIndices(DescriptorOp op, OpAdaptor adaptor,
2950 ConversionPatternRewriter &rewriter, Location loc,
2951 Value sgpr4, ArrayRef<Value> consts) const {
2952 auto type = cast<VectorType>(op.getIndices().getType());
2953 ArrayRef<int64_t> shape = type.getShape();
2954 assert(shape.size() == 1 && "expected shape to be of rank 1.");
2955 unsigned length = shape.back();
2956 assert(0 < length && length <= 16 && "expected length to be between 1 and 16.");
2957 Value value = createI32Constant(rewriter, loc, length);
2958 return setValueAtOffset(rewriter, loc, sgpr4, value, 128);
2959 }
2960
2961 Value setTileDim1OrValidIndices(DescriptorOp op, OpAdaptor adaptor,
2962 ConversionPatternRewriter &rewriter,
2963 Location loc, Value sgpr4,
2964 ArrayRef<Value> consts) const {
2965 if constexpr (DescriptorOp::isGather())
2966 return setValidIndices(op, adaptor, rewriter, loc, sgpr4, consts);
2967 return setTileDim1(op, adaptor, rewriter, loc, sgpr4, consts);
2968 }
2969
2970 Value setTileDim2(DescriptorOp op, OpAdaptor adaptor,
2971 ConversionPatternRewriter &rewriter, Location loc,
2972 Value sgpr4, ArrayRef<Value> consts) const {
2973 // Value is ignored when in gather mode.
2974 if constexpr (DescriptorOp::isGather())
2975 return sgpr4;
2976 return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 2, 144);
2977 }
2978
2979 std::pair<Value, Value>
2980 setTensorDimXStride(DescriptorOp op, OpAdaptor adaptor,
2981 ConversionPatternRewriter &rewriter, Location loc,
2982 Value sgprY, Value sgprZ, ArrayRef<Value> consts,
2983 size_t dimX, int64_t offset) const {
2984 ArrayRef<int64_t> globalStaticStrides = adaptor.getGlobalStaticStrides();
2985 ValueRange globalDynamicStrides = adaptor.getGlobalDynamicStrides();
2986 SmallVector<OpFoldResult> mixedGlobalStrides =
2987 getMixedValues(globalStaticStrides, globalDynamicStrides, rewriter);
2988
2989 if (mixedGlobalStrides.size() <= (dimX + 1))
2990 return {sgprY, sgprZ};
2991
2992 OpFoldResult tensorDimXStrideOpFoldResult =
2993 *(mixedGlobalStrides.rbegin() + dimX + 1);
2994 // pre-condition: tensorDimXStride is less than 2^48 - 1.
2995 // TODO: validate that the value satisfies this pre-condition.
2996 // In a follow-up PR, implement RuntimeVerifiableOpInterface to instrument
2997 // conditions that need to be checked at runtime.
2998 Value tensorDimXStride;
2999 if (auto attr = dyn_cast<Attribute>(tensorDimXStrideOpFoldResult))
3000 tensorDimXStride =
3001 createI64Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
3002 else
3003 tensorDimXStride = cast<Value>(tensorDimXStrideOpFoldResult);
3004
3005 constexpr int64_t first48bits = (1ll << 48) - 1;
3006 Value mask = createI64Constant(rewriter, loc, first48bits);
3007 tensorDimXStride =
3008 LLVM::AndOp::create(rewriter, loc, mask, tensorDimXStride);
3009 IntegerType i32 = rewriter.getI32Type();
3010 Value tensorDimXStrideLow =
3011 LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStride);
3012 sgprY = setValueAtOffset(rewriter, loc, sgprY, tensorDimXStrideLow, offset);
3013
3014 int64_t shift = (offset % 32) == 0 ? 32 : offset % 32;
3015 Value shiftVal = createI64Constant(rewriter, loc, shift);
3016 Value tensorDimXStrideHigh =
3017 LLVM::LShrOp::create(rewriter, loc, tensorDimXStride, shiftVal);
3018 tensorDimXStrideHigh =
3019 LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStrideHigh);
3020 sgprZ = setValueAtOffset(rewriter, loc, sgprZ, tensorDimXStrideHigh,
3021 offset + shift);
3022 return {sgprY, sgprZ};
3023 }
3024
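// Sketch of the 48-bit stride split above: for dim 0 (offset 160, a word
// boundary) the shift evaluates to 32, so the low 32 bits of the masked
// stride land in word 5 and bits [47:32] land in word 6; for dim 1
// (offset 208) the shift is 16, so the low bits are OR'd into the upper half
// of word 6 and stride >> 16 into word 7. For example, a stride of
// 0x1'0000'0010 at offset 160 stores 0x10 in word 5 and 0x1 in word 6.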
3025 std::pair<Value, Value>
3026 setTensorDim0Stride(DescriptorOp op, OpAdaptor adaptor,
3027 ConversionPatternRewriter &rewriter, Location loc,
3028 Value sgpr5, Value sgpr6, ArrayRef<Value> consts) const {
3029 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts,
3030 0, 160);
3031 }
3032
3033 std::pair<Value, Value>
3034 setTensorDim1Stride(DescriptorOp op, OpAdaptor adaptor,
3035 ConversionPatternRewriter &rewriter, Location loc,
3036 Value sgpr5, Value sgpr6, ArrayRef<Value> consts) const {
3037 // Value is ignored when in gather mode.
3038 if constexpr (DescriptorOp::isGather())
3039 return {sgpr5, sgpr6};
3040 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts,
3041 1, 208);
3042 }
3043
3044 Value getDGroup1(DescriptorOp op, OpAdaptor adaptor,
3045 ConversionPatternRewriter &rewriter, Location loc,
3046 ArrayRef<Value> consts) const {
3047 Value sgprs[8];
3048 for (int64_t i = 0; i < 8; ++i) {
3049 sgprs[i] = consts[0];
3050 }
3051
3052 sgprs[0] = setWorkgroupMask(op, adaptor, rewriter, loc, sgprs[0]);
3053 sgprs[0] = setDataSize(op, adaptor, rewriter, loc, sgprs[0], consts);
3054 sgprs[0] = setAtomicBarrier(op, adaptor, rewriter, loc, sgprs[0], consts);
3055 sgprs[0] = setIterateEnable(op, adaptor, rewriter, loc, sgprs[0], consts);
3056 sgprs[0] = setPadEnable(op, adaptor, rewriter, loc, sgprs[0], consts);
3057 sgprs[0] = setEarlyTimeout(op, adaptor, rewriter, loc, sgprs[0], consts);
3058 sgprs[0] = setPadInterval(op, adaptor, rewriter, loc, sgprs[0], consts);
3059 sgprs[0] = setPadAmount(op, adaptor, rewriter, loc, sgprs[0], consts);
3060
3061 sgprs[1] =
3062 setAtomicBarrierAddress(op, adaptor, rewriter, loc, sgprs[1], consts);
3063 std::tie(sgprs[1], sgprs[2]) =
3064 setTensorDim0(op, adaptor, rewriter, loc, sgprs[1], sgprs[2], consts);
3065 std::tie(sgprs[2], sgprs[3]) =
3066 setTensorDim1(op, adaptor, rewriter, loc, sgprs[2], sgprs[3], consts);
3067
3068 sgprs[3] = setTileDim0(op, adaptor, rewriter, loc, sgprs[3], consts);
3069 sgprs[4] =
3070 setTileDim1OrValidIndices(op, adaptor, rewriter, loc, sgprs[4], consts);
3071 sgprs[4] = setTileDim2(op, adaptor, rewriter, loc, sgprs[4], consts);
3072 std::tie(sgprs[5], sgprs[6]) = setTensorDim0Stride(
3073 op, adaptor, rewriter, loc, sgprs[5], sgprs[6], consts);
3074 std::tie(sgprs[6], sgprs[7]) = setTensorDim1Stride(
3075 op, adaptor, rewriter, loc, sgprs[6], sgprs[7], consts);
3076
3077 IntegerType i32 = rewriter.getI32Type();
3078 Type v8i32 = this->typeConverter->convertType(VectorType::get(8, i32));
3079 assert(v8i32 && "expected type conversion to succeed");
3080 Value dgroup1 = LLVM::PoisonOp::create(rewriter, loc, v8i32);
3081
3082 for (auto [sgpr, constant] : llvm::zip_equal(sgprs, consts)) {
3083 dgroup1 =
3084 LLVM::InsertElementOp::create(rewriter, loc, dgroup1, sgpr, constant);
3085 }
3086
3087 return dgroup1;
3088 }
3089
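// Field layout of dgroup1 as inferred from the bit offsets used above (eight
// 32-bit words; field widths are implied by the producing helpers rather
// than restated here):
//   word 0: workgroup mask [15:0], data-size code @16, atomic-barrier
//           enable @18, iterate enable @19, pad enable @20, early
//           timeout @21, pad interval @22, pad amount @25
//   word 1: atomic barrier address @32, tensor dim 0 (low half) @48
//   word 2: tensor dim 0 (high half) @64, tensor dim 1 (low half) @80
//   word 3: tensor dim 1 (high half) @96, tile dim 0 @112
//   word 4: tile dim 1 (or valid-index count when gathering) @128,
//           tile dim 2 @144
//   words 5-7: tensor dim 0 and dim 1 strides, written by the stride
//              helpers above.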
3090 Value setTensorDimX(DescriptorOp op, OpAdaptor adaptor,
3091 ConversionPatternRewriter &rewriter, Location loc,
3092 Value sgpr0, ArrayRef<Value> consts, int64_t dimX,
3093 int64_t offset) const {
3094 ArrayRef<int64_t> globalStaticSizes = adaptor.getGlobalStaticSizes();
3095 ValueRange globalDynamicSizes = adaptor.getGlobalDynamicSizes();
3096 SmallVector<OpFoldResult> mixedGlobalSizes =
3097 getMixedValues(globalStaticSizes, globalDynamicSizes, rewriter);
3098 if (mixedGlobalSizes.size() <= static_cast<unsigned long>(dimX))
3099 return sgpr0;
3100
3101 OpFoldResult tensorDimXOpFoldResult = *(mixedGlobalSizes.rbegin() + dimX);
3102 Value tensorDimX;
3103 if (auto attr = dyn_cast<Attribute>(tensorDimXOpFoldResult)) {
3104 tensorDimX =
3105 createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
3106 } else {
3107 IntegerType i32 = rewriter.getI32Type();
3108 tensorDimX = cast<Value>(tensorDimXOpFoldResult);
3109 tensorDimX = LLVM::TruncOp::create(rewriter, loc, i32, tensorDimX);
3110 }
3111
3112 return setValueAtOffset(rewriter, loc, sgpr0, tensorDimX, offset);
3113 }
3114
3115 Value setTensorDim2(DescriptorOp op, OpAdaptor adaptor,
3116 ConversionPatternRewriter &rewriter, Location loc,
3117 Value sgpr0, ArrayRef<Value> consts) const {
3118 return setTensorDimX(op, adaptor, rewriter, loc, sgpr0, consts, 2, 0);
3119 }
3120
3121 Value truncateAndSetValueAtOffset(ConversionPatternRewriter &rewriter,
3122 Location loc, Value accumulator,
3123 Value value, int64_t shift) const {
3124
3125 IntegerType i32 = rewriter.getI32Type();
3126 value = LLVM::TruncOp::create(rewriter, loc, i32, value);
3127 return setValueAtOffset(rewriter, loc, accumulator, value, shift);
3128 }
3129
3130 Value setLDSAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
3131 ConversionPatternRewriter &rewriter, Location loc,
3132 Value sgpr1, ArrayRef<Value> consts,
3133 int64_t offset) const {
3134 Value ldsAddrIncrement = adaptor.getLdsIncrement();
3135 return setValueAtOffset(rewriter, loc, sgpr1, ldsAddrIncrement, offset);
3136 }
3137
3138 std::pair<Value, Value>
3139 setGlobalAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
3140 ConversionPatternRewriter &rewriter, Location loc,
3141 Value sgpr2, Value sgpr3, ArrayRef<Value> consts,
3142 int64_t offset) const {
3143 Value globalAddrIncrement = adaptor.getGlobalIncrement();
3144 sgpr2 = truncateAndSetValueAtOffset(rewriter, loc, sgpr2,
3145 globalAddrIncrement, offset);
3146 Value shift = createI64Constant(rewriter, loc, 32);
3147 globalAddrIncrement =
3148 LLVM::LShrOp::create(rewriter, loc, globalAddrIncrement, shift);
3149 constexpr int64_t first16BitsHigh = (1ll << 16) - 1;
3150 sgpr3 = truncateAndSetValueAtOffset(rewriter, loc, sgpr3,
3151 globalAddrIncrement, offset + 32);
3152 Value mask = createI32Constant(rewriter, loc, first16BitsHigh);
3153 sgpr3 = LLVM::AndOp::create(rewriter, loc, sgpr3, mask);
3154 return {sgpr2, sgpr3};
3155 }
3156
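// Worked example for setGlobalAddrIncrement above: a 64-bit increment of
// 0x0001'2345'6789'ABCD at offset 64 truncates to 0x6789ABCD in word 2,
// while (increment >> 32) truncated and masked to 16 bits leaves 0x2345 in
// the low half of word 3. The 0xFFFF mask is applied to the whole sgpr3
// accumulator, which is safe only because nothing has been written to that
// word yet by the time getDGroup2NonGather calls this helper.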
3157 Value setTensorDim3OrLDSAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
3158 ConversionPatternRewriter &rewriter,
3159 Location loc, Value sgpr1,
3160 ArrayRef<Value> consts) const {
3161 Value ldsIncrement = op.getLdsIncrement();
3162 constexpr int64_t dim = 3;
3163 constexpr int64_t offset = 32;
3164 if (!ldsIncrement)
3165 return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, consts, dim,
3166 offset);
3167 return setLDSAddrIncrement(op, adaptor, rewriter, loc, sgpr1, consts,
3168 offset);
3169 }
3170
3171 std::pair<Value, Value> setTensorDim2StrideOrGlobalAddrIncrement(
3172 DescriptorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter,
3173 Location loc, Value sgpr2, Value sgpr3, ArrayRef<Value> consts) const {
3174 Value globalIncrement = op.getGlobalIncrement();
3175 constexpr int32_t dim = 2;
3176 constexpr int32_t offset = 64;
3177 if (!globalIncrement)
3178 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr2, sgpr3,
3179 consts, dim, offset);
3180 return setGlobalAddrIncrement(op, adaptor, rewriter, loc, sgpr2, sgpr3,
3181 consts, offset);
3182 }
3183
3184 Value setIterateCount(DescriptorOp op, OpAdaptor adaptor,
3185 ConversionPatternRewriter &rewriter, Location loc,
3186 Value sgpr3, ArrayRef<Value> consts,
3187 int32_t offset) const {
3188 Value iterationCount = adaptor.getIterationCount();
3189 IntegerType i32 = rewriter.getI32Type();
3190 // pre-condition: iterationCount is in the inclusive interval [1, 256].
3191 // TODO: validate that the value satisfies this pre-condition.
3192 // If the pre-condition is violated, the higher bits of the
3193 // descriptor word may be corrupted. In a follow-up PR, implement
3194 // RuntimeVerifiableOpInterface to instrument conditions that need to be
3195 // checked at runtime.
3196 iterationCount = LLVM::TruncOp::create(rewriter, loc, i32, iterationCount);
3197 iterationCount =
3198 LLVM::SubOp::create(rewriter, loc, iterationCount, consts[1]);
3199 return setValueAtOffset(rewriter, loc, sgpr3, iterationCount, offset);
3200 }
3201
3202 Value setTileDim3OrIterateCount(DescriptorOp op, OpAdaptor adaptor,
3203 ConversionPatternRewriter &rewriter,
3204 Location loc, Value sgpr3,
3205 ArrayRef<Value> consts) const {
3206 Value iterateCount = op.getIterationCount();
3207 constexpr int32_t dim = 2;
3208 constexpr int32_t offset = 112;
3209 if (!iterateCount)
3210 return setTileDimX(op, adaptor, rewriter, loc, sgpr3, consts, dim,
3211 offset);
3212
3213 return setIterateCount(op, adaptor, rewriter, loc, sgpr3, consts, offset);
3214 }
3215
3216 Value getDGroup2(DescriptorOp op, OpAdaptor adaptor,
3217 ConversionPatternRewriter &rewriter, Location loc,
3218 ArrayRef<Value> consts) const {
3219 if constexpr (DescriptorOp::isGather())
3220 return getDGroup2Gather(op, adaptor, rewriter, loc, consts);
3221 return getDGroup2NonGather(op, adaptor, rewriter, loc, consts);
3222 }
3223
3224 Value getDGroup2NonGather(DescriptorOp op, OpAdaptor adaptor,
3225 ConversionPatternRewriter &rewriter, Location loc,
3226 ArrayRef<Value> consts) const {
3227 IntegerType i32 = rewriter.getI32Type();
3228 Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
3229 assert(v4i32 && "expected type conversion to succeed.");
3230
3231 bool onlyNeedsTwoDescriptors = !op.getLdsIncrement() && op.getRank() <= 2;
3232 if (onlyNeedsTwoDescriptors)
3233 return LLVM::ZeroOp::create(rewriter, loc, v4i32);
3234
3235 constexpr int64_t sgprlen = 4;
3236 Value sgprs[sgprlen];
3237 for (int i = 0; i < sgprlen; ++i)
3238 sgprs[i] = consts[0];
3239
3240 sgprs[0] = setTensorDim2(op, adaptor, rewriter, loc, sgprs[0], consts);
3241 sgprs[1] = setTensorDim3OrLDSAddrIncrement(op, adaptor, rewriter, loc,
3242 sgprs[1], consts);
3243 std::tie(sgprs[2], sgprs[3]) = setTensorDim2StrideOrGlobalAddrIncrement(
3244 op, adaptor, rewriter, loc, sgprs[2], sgprs[3], consts);
3245 sgprs[3] =
3246 setTileDim3OrIterateCount(op, adaptor, rewriter, loc, sgprs[3], consts);
3247
3248 Value dgroup2 = LLVM::PoisonOp::create(rewriter, loc, v4i32);
3249 for (auto [sgpr, constant] : llvm::zip(sgprs, consts))
3250 dgroup2 =
3251 LLVM::InsertElementOp::create(rewriter, loc, dgroup2, sgpr, constant);
3252
3253 return dgroup2;
3254 }
3255
3256 Value getGatherIndices(DescriptorOp op, OpAdaptor adaptor,
3257 ConversionPatternRewriter &rewriter, Location loc,
3258 ArrayRef<Value> consts, bool firstHalf) const {
3259 IntegerType i32 = rewriter.getI32Type();
3260 Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
3261 assert(v4i32 && "expected type conversion to succeed.");
3262
3263 Value indices = adaptor.getIndices();
3264 auto vectorType = cast<VectorType>(indices.getType());
3265 unsigned length = vectorType.getShape().back();
3266 Type elementType = vectorType.getElementType();
3267 unsigned maxLength = elementType == i32 ? 4 : 8;
3268 int32_t offset = firstHalf ? 0 : maxLength;
3269 unsigned discountedLength =
3270 std::max(static_cast<int32_t>(length - offset), 0);
3271
3272 unsigned targetSize = std::min(maxLength, discountedLength);
3273
3274 SmallVector<Value> indicesVector;
3275 for (unsigned i = offset; i < targetSize + offset; ++i) {
3276 Value idx;
3277 if (i < consts.size())
3278 idx = consts[i];
3279 else
3280 idx = createI32Constant(rewriter, loc, i);
3281 Value elem = LLVM::ExtractElementOp::create(rewriter, loc, indices, idx);
3282 indicesVector.push_back(elem);
3283 }
3284
3285 SmallVector<Value> indicesI32Vector;
3286 if (elementType == i32) {
3287 indicesI32Vector = indicesVector;
3288 } else {
3289 for (unsigned i = 0; i < targetSize; ++i) {
3290 Value index = indicesVector[i];
3291 indicesI32Vector.push_back(
3292 LLVM::ZExtOp::create(rewriter, loc, i32, index));
3293 }
3294 if ((targetSize % 2) != 0)
3295 // Add padding when not divisible by two.
3296 indicesI32Vector.push_back(consts[0]);
3297 }
3298
3299 SmallVector<Value> indicesToInsert;
3300 if (elementType == i32) {
3301 indicesToInsert = indicesI32Vector;
3302 } else {
3303 unsigned size = indicesI32Vector.size() / 2;
3304 for (unsigned i = 0; i < size; ++i) {
3305 Value first = indicesI32Vector[2 * i];
3306 Value second = indicesI32Vector[2 * i + 1];
3307 Value joined = setValueAtOffset(rewriter, loc, first, second, 16);
3308 indicesToInsert.push_back(joined);
3309 }
3310 }
3311
3312 Value dgroup = LLVM::PoisonOp::create(rewriter, loc, v4i32);
3313 for (auto [sgpr, constant] : llvm::zip_first(indicesToInsert, consts))
3314 dgroup =
3315 LLVM::InsertElementOp::create(rewriter, loc, dgroup, sgpr, constant);
3316
3317 return dgroup;
3318 }
3319
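// Sketch of the index packing above for sub-32-bit (e.g. i16) gather
// indices, assuming setValueAtOffset here ORs `second << 16` into `first`:
// each index is zero-extended to i32 and pairs are merged two per word, so
// three indices [a, b, c] become word 0 = a | (b << 16) and word 1 = c
// (padded with the constant 0 because the count is odd). i32 indices are
// instead inserted one per word, four per descriptor group.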
3320 Value getDGroup2Gather(DescriptorOp op, OpAdaptor adaptor,
3321 ConversionPatternRewriter &rewriter, Location loc,
3322 ArrayRef<Value> consts) const {
3323 return getGatherIndices(op, adaptor, rewriter, loc, consts, true);
3324 }
3325
3326 std::pair<Value, Value>
3327 setTensorDim3Stride(DescriptorOp op, OpAdaptor adaptor,
3328 ConversionPatternRewriter &rewriter, Location loc,
3329 Value sgpr0, Value sgpr1, ArrayRef<Value> consts) const {
3330 constexpr int32_t dim = 3;
3331 constexpr int32_t offset = 0;
3332 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr0, sgpr1, consts,
3333 dim, offset);
3334 }
3335
3336 std::pair<Value, Value> setTensorDim4(DescriptorOp op, OpAdaptor adaptor,
3337 ConversionPatternRewriter &rewriter,
3338 Location loc, Value sgpr1, Value sgpr2,
3339 ArrayRef<Value> consts) const {
3340 constexpr int32_t dim = 4;
3341 constexpr int32_t offset = 48;
3342 return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, sgpr2, consts, dim,
3343 offset);
3344 }
3345
3346 Value setTileDim4(DescriptorOp op, OpAdaptor adaptor,
3347 ConversionPatternRewriter &rewriter, Location loc,
3348 Value sgpr2, ArrayRef<Value> consts) const {
3349 constexpr int32_t dim = 4;
3350 constexpr int32_t offset = 80;
3351 return setTileDimX(op, adaptor, rewriter, loc, sgpr2, consts, dim, offset);
3352 }
3353
3354 Value getDGroup3(DescriptorOp op, OpAdaptor adaptor,
3355 ConversionPatternRewriter &rewriter, Location loc,
3356 ArrayRef<Value> consts) const {
3357 if constexpr (DescriptorOp::isGather())
3358 return getDGroup3Gather(op, adaptor, rewriter, loc, consts);
3359 return getDGroup3NonGather(op, adaptor, rewriter, loc, consts);
3360 }
3361
3362 Value getDGroup3NonGather(DescriptorOp op, OpAdaptor adaptor,
3363 ConversionPatternRewriter &rewriter, Location loc,
3364 ArrayRef<Value> consts) const {
3365 IntegerType i32 = rewriter.getI32Type();
3366 Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
3367 assert(v4i32 && "expected type conversion to succeed.");
3368 bool onlyNeedsTwoDescriptors = !op.getLdsIncrement() && op.getRank() <= 2;
3369 if (onlyNeedsTwoDescriptors)
3370 return LLVM::ZeroOp::create(rewriter, loc, v4i32);
3371
3372 constexpr int32_t sgprlen = 4;
3373 Value sgprs[sgprlen];
3374 for (int i = 0; i < sgprlen; ++i)
3375 sgprs[i] = consts[0];
3376
3377 std::tie(sgprs[0], sgprs[1]) = setTensorDim3Stride(
3378 op, adaptor, rewriter, loc, sgprs[0], sgprs[1], consts);
3379 std::tie(sgprs[1], sgprs[2]) =
3380 setTensorDim4(op, adaptor, rewriter, loc, sgprs[1], sgprs[2], consts);
3381 sgprs[2] = setTileDim4(op, adaptor, rewriter, loc, sgprs[2], consts);
3382
3383 Value dgroup3 = LLVM::PoisonOp::create(rewriter, loc, v4i32);
3384 for (auto [sgpr, constant] : llvm::zip(sgprs, consts))
3385 dgroup3 =
3386 LLVM::InsertElementOp::create(rewriter, loc, dgroup3, sgpr, constant);
3387
3388 return dgroup3;
3389 }
3390
3391 Value getDGroup3Gather(DescriptorOp op, OpAdaptor adaptor,
3392 ConversionPatternRewriter &rewriter, Location loc,
3393 ArrayRef<Value> consts) const {
3394 return getGatherIndices(op, adaptor, rewriter, loc, consts, false);
3395 }
3396
3397 LogicalResult
3398 matchAndRewrite(DescriptorOp op, OpAdaptor adaptor,
3399 ConversionPatternRewriter &rewriter) const override {
3400 if (chipset < kGfx1250)
3401 return op->emitOpError(
3402 "make_dma_descriptor is only supported on gfx1250");
3403
3404 Location loc = op.getLoc();
3405
3406 SmallVector<Value> consts;
3407 for (int64_t i = 0; i < 8; ++i)
3408 consts.push_back(createI32Constant(rewriter, loc, i));
3409
3410 Value dgroup0 = this->getDGroup0(adaptor);
3411 Value dgroup1 = this->getDGroup1(op, adaptor, rewriter, loc, consts);
3412 Value dgroup2 = this->getDGroup2(op, adaptor, rewriter, loc, consts);
3413 Value dgroup3 = this->getDGroup3(op, adaptor, rewriter, loc, consts);
3414 SmallVector<Value> results = {dgroup0, dgroup1, dgroup2, dgroup3};
3415 rewriter.replaceOpWithMultiple(op, {results});
3416 return success();
3417 }
3418};
3419
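// The descriptor type itself converts 1:4 (v4i32, v8i32, v4i32, v4i32; see
// the TDMDescriptorType conversion registered further down), so the
// matchAndRewrite above replaces the op with the four dgroup values via
// replaceOpWithMultiple. Consumers such as the tensor load/store lowering
// below then receive the groups as desc[0] through desc[3].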
3420template <typename SourceOp, typename TargetOp>
3421struct AMDGPUTensorLoadStoreOpLowering
3422 : public ConvertOpToLLVMPattern<SourceOp> {
3423 using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
3424 using Adaptor = typename ConvertOpToLLVMPattern<SourceOp>::OneToNOpAdaptor;
3425 AMDGPUTensorLoadStoreOpLowering(const LLVMTypeConverter &converter,
3426 Chipset chipset)
3427 : ConvertOpToLLVMPattern<SourceOp>(converter), chipset(chipset) {}
3428 Chipset chipset;
3429
3430 LogicalResult
3431 matchAndRewrite(SourceOp op, Adaptor adaptor,
3432 ConversionPatternRewriter &rewriter) const override {
3433 if (chipset < kGfx1250)
3434 return op->emitOpError("is only supported on gfx1250");
3435
3436 ValueRange desc = adaptor.getDesc();
3437 rewriter.replaceOpWithNewOp<TargetOp>(op, desc[0], desc[1], desc[2],
3438 desc[3], /*cachePolicy=*/0,
3439 /*alias_scopes=*/nullptr,
3440 /*noalias_scopes=*/nullptr,
3441 /*tbaa=*/nullptr);
3442 return success();
3443 }
3444};
3445
3446struct ConvertAMDGPUToROCDLPass
3447 : public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
3448 using Base::Base;
3449
3450 void runOnOperation() override {
3451 MLIRContext *ctx = &getContext();
3452 FailureOr<Chipset> maybeChipset = Chipset::parse(chipset);
3453 if (failed(maybeChipset)) {
3454 emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
3455 return signalPassFailure();
3456 }
3457
3458 RewritePatternSet patterns(ctx);
3459 LLVMTypeConverter converter(ctx);
3460
3461 populateAMDGPUToROCDLConversionPatterns(converter, patterns, *maybeChipset);
3462 amdgpu::populateCommonGPUTypeAndAttributeConversions(converter);
3463 LLVMConversionTarget target(getContext());
3464 target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
3465 target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
3466 target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
3467 if (failed(applyPartialConversion(getOperation(), target,
3468 std::move(patterns))))
3469 signalPassFailure();
3470 }
3471};
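// For reference, assuming the pass generated from
// GEN_PASS_DEF_CONVERTAMDGPUTOROCDLPASS is registered in Passes.td as
// `convert-amdgpu-to-rocdl`, the lowering above can be exercised from
// mlir-opt roughly as:
//   mlir-opt input.mlir --convert-amdgpu-to-rocdl=chipset=gfx1250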
3472} // namespace
3473
3474void mlir::amdgpu::populateCommonGPUTypeAndAttributeConversions(
3475 TypeConverter &typeConverter) {
3476 populateGpuMemorySpaceAttributeConversions(
3477 typeConverter, [](gpu::AddressSpace space) {
3478 switch (space) {
3479 case gpu::AddressSpace::Global:
3480 return ROCDL::ROCDLDialect::kGlobalMemoryAddressSpace;
3481 case gpu::AddressSpace::Workgroup:
3482 return ROCDL::ROCDLDialect::kSharedMemoryAddressSpace;
3483 case gpu::AddressSpace::Private:
3484 return ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace;
3485 }
3486 llvm_unreachable("unknown address space enum value");
3487 });
3488}
3489
3490void mlir::amdgpu::populateAMDGPUTypeAndAttributeConversions(
3491 TypeConverter &typeConverter) {
3492 typeConverter.addTypeAttributeConversion(
3493 [](BaseMemRefType type, amdgpu::AddressSpaceAttr as)
3494 -> TypeConverter::AttributeConversionResult {
3495 MLIRContext *ctx = as.getContext();
3496 Type i64 = IntegerType::get(ctx, 64);
3497 switch (as.getValue()) {
3498 case amdgpu::AddressSpace::FatRawBuffer:
3499 return IntegerAttr::get(i64, 7);
3500 case amdgpu::AddressSpace::BufferRsrc:
3501 return IntegerAttr::get(i64, 8);
3502 case amdgpu::AddressSpace::FatStructuredBuffer:
3503 return IntegerAttr::get(i64, 9);
3504 }
3505 return TypeConverter::AttributeConversionResult::abort();
3506 });
3507 typeConverter.addConversion([&](TDMBaseType type) -> Type {
3508 Type i32 = IntegerType::get(type.getContext(), 32);
3509 return typeConverter.convertType(VectorType::get(4, i32));
3510 });
3511 typeConverter.addConversion([&](TDMGatherBaseType type) -> Type {
3512 Type i32 = IntegerType::get(type.getContext(), 32);
3513 return typeConverter.convertType(VectorType::get(4, i32));
3514 });
3515 typeConverter.addConversion(
3516 [&](TDMDescriptorType type,
3517 SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
3518 Type i32 = IntegerType::get(type.getContext(), 32);
3519 Type v4i32 = typeConverter.convertType(VectorType::get(4, i32));
3520 Type v8i32 = typeConverter.convertType(VectorType::get(8, i32));
3521 llvm::append_values(result, v4i32, v8i32, v4i32, v4i32);
3522 return success();
3523 });
3524
3525 auto addUnrealizedCast = [](OpBuilder &builder, TypeRange types,
3526 ValueRange inputs,
3527 Location loc) -> SmallVector<Value> {
3528 // Only create unrealized_conversion_cast for TDMDescriptorType.
3529 // Any other, unexpected types should be materialized by other target
3530 // materialization functions.
3531 if (inputs.size() != 1)
3532 return {};
3533
3534 if (!isa<TDMDescriptorType>(inputs[0].getType()))
3535 return {};
3536
3537 auto cast = UnrealizedConversionCastOp::create(builder, loc, types, inputs);
3538 return cast.getResults();
3539 };
3540
3541 typeConverter.addTargetMaterialization(addUnrealizedCast);
3542}
3543
3544void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
3545 RewritePatternSet &patterns,
3546 Chipset chipset) {
3548 patterns
3549 .add<FatRawBufferCastLowering,
3550 RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
3551 RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
3552 RawBufferOpLowering<RawBufferAtomicFaddOp,
3553 ROCDL::RawPtrBufferAtomicFaddOp>,
3554 RawBufferOpLowering<RawBufferAtomicFmaxOp,
3555 ROCDL::RawPtrBufferAtomicFmaxOp>,
3556 RawBufferOpLowering<RawBufferAtomicSmaxOp,
3557 ROCDL::RawPtrBufferAtomicSmaxOp>,
3558 RawBufferOpLowering<RawBufferAtomicUminOp,
3559 ROCDL::RawPtrBufferAtomicUminOp>,
3560 RawBufferOpLowering<RawBufferAtomicCmpswapOp,
3561 ROCDL::RawPtrBufferAtomicCmpSwap>,
3562 AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
3563 SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
3564 SparseMFMAOpLowering, WMMAOpLowering, ScaledWMMAOpLowering,
3565 ExtPackedFp8OpLowering, ScaledExtPackedMatrixOpLowering,
3566 ScaledExtPackedOpLowering, PackedScaledTruncOpLowering,
3567 PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
3568 GatherToLDSOpLowering, TransposeLoadOpLowering,
3569 AMDGPUPermlaneLowering, AMDGPUMakeDmaBaseLowering<MakeDmaBaseOp>,
3570 AMDGPUMakeDmaBaseLowering<MakeGatherDmaBaseOp>,
3571 AMDGPULowerDescriptor<MakeDmaDescriptorOp>,
3572 AMDGPULowerDescriptor<MakeGatherDmaDescriptorOp>,
3573 AMDGPUTensorLoadStoreOpLowering<TensorLoadToLDSOp,
3574 ROCDL::TensorLoadToLDSOp>,
3575 AMDGPUTensorLoadStoreOpLowering<TensorStoreFromLDSOp,
3576 ROCDL::TensorStoreFromLDSOp>>(
3577 converter, chipset);
3578 patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
3579}