MLIR 23.0.0git
AMDGPUToROCDL.cpp
Go to the documentation of this file.
1//===- AMDGPUToROCDL.cpp - AMDGPU to ROCDL dialect conversion -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
20#include "mlir/IR/Attributes.h"
23#include "mlir/IR/Matchers.h"
25#include "mlir/Pass/Pass.h"
26
28
29#include "llvm/ADT/STLExtras.h"
30#include "llvm/ADT/TypeSwitch.h"
31#include "llvm/Support/Casting.h"
32#include "llvm/Support/ErrorHandling.h"
33#include <optional>
34
35namespace mlir {
36#define GEN_PASS_DEF_CONVERTAMDGPUTOROCDLPASS
37#include "mlir/Conversion/Passes.h.inc"
38} // namespace mlir
39
40using namespace mlir;
41using namespace mlir::amdgpu;
42
// Define commonly used chipsets versions for convenience.
// These anchor the `chipset >= kGfx...` / `chipset == kGfx...` feature gates
// used throughout the lowerings below.
constexpr Chipset kGfx908 = Chipset(9, 0, 8);
constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
constexpr Chipset kGfx942 = Chipset(9, 4, 2);
constexpr Chipset kGfx950 = Chipset(9, 5, 0);
constexpr Chipset kGfx1250 = Chipset(12, 5, 0);
49
50/// Convert an unsigned number `val` to i32.
51static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
52 Location loc, Value val) {
53 IntegerType i32 = rewriter.getI32Type();
54 // Force check that `val` is of int type.
55 auto valTy = cast<IntegerType>(val.getType());
56 if (i32 == valTy)
57 return val;
58 return valTy.getWidth() > 32
59 ? Value(LLVM::TruncOp::create(rewriter, loc, i32, val))
60 : Value(LLVM::ZExtOp::create(rewriter, loc, i32, val));
61}
62
63static Value createI32Constant(ConversionPatternRewriter &rewriter,
64 Location loc, int32_t value) {
65 return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), value);
66}
67
68/// Convert an unsigned number `val` to i64.
69static Value convertUnsignedToI64(ConversionPatternRewriter &rewriter,
70 Location loc, Value val) {
71 IntegerType i64 = rewriter.getI64Type();
72 // Force check that `val` is of int type.
73 auto valTy = cast<IntegerType>(val.getType());
74 if (i64 == valTy)
75 return val;
76 return valTy.getWidth() > 64
77 ? Value(LLVM::TruncOp::create(rewriter, loc, i64, val))
78 : Value(LLVM::ZExtOp::create(rewriter, loc, i64, val));
79}
80
81static Value createI64Constant(ConversionPatternRewriter &rewriter,
82 Location loc, int64_t value) {
83 return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), value);
84}
85
86/// Returns the linear index used to access an element in the memref.
87static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
88 Location loc, MemRefDescriptor &memRefDescriptor,
90 IntegerType i32 = rewriter.getI32Type();
92 for (auto [i, increment, stride] : llvm::enumerate(indices, strides)) {
93 if (stride != 1) { // Skip if stride is 1.
94 Value strideValue =
95 ShapedType::isDynamic(stride)
96 ? convertUnsignedToI32(rewriter, loc,
97 memRefDescriptor.stride(rewriter, loc, i))
98 : LLVM::ConstantOp::create(rewriter, loc, i32, stride);
99 increment = LLVM::MulOp::create(rewriter, loc, increment, strideValue);
100 }
101 index = index ? LLVM::AddOp::create(rewriter, loc, index, increment)
102 : increment;
103 }
104 return index ? index : createI32Constant(rewriter, loc, 0);
105}
106
107/// Compute the contents of the `num_records` field for a given memref
108/// descriptor - that is, the number of bytes that's one element past the
109/// greatest possible valid index into the memref.
110static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc,
111 MemRefType memrefType,
112 MemRefDescriptor &memrefDescriptor,
113 ArrayRef<int64_t> strides, int64_t elementByteWidth,
114 amdgpu::Chipset chipset, bool boundsCheck) {
115 if (chipset >= kGfx1250 && !boundsCheck) {
116 constexpr int64_t first45bits = (1ll << 45) - 1;
117 return createI64Constant(rewriter, loc, first45bits);
118 }
119 if (memrefType.hasStaticShape() &&
120 !llvm::any_of(strides, ShapedType::isDynamic)) {
121 int64_t size = memrefType.getRank() == 0 ? 1 : 0;
122 ArrayRef<int64_t> shape = memrefType.getShape();
123 for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
124 size = std::max(shape[i] * strides[i], size);
125 size = size * elementByteWidth;
126 return createI64Constant(rewriter, loc, size);
127 }
128 Value maxIndex;
129 for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
130 Value size = memrefDescriptor.size(rewriter, loc, i);
131 Value stride = memrefDescriptor.stride(rewriter, loc, i);
132 Value maxThisDim = LLVM::MulOp::create(rewriter, loc, size, stride);
133 maxIndex = maxIndex
134 ? LLVM::UMaxOp::create(rewriter, loc, maxIndex, maxThisDim)
135 : maxThisDim;
136 }
137 Value maxIndexI64 = convertUnsignedToI64(rewriter, loc, maxIndex);
138 Value byteWidthConst = createI64Constant(rewriter, loc, elementByteWidth);
139 return LLVM::MulOp::create(rewriter, loc, maxIndexI64, byteWidthConst);
140}
141
142static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc,
143 Value basePointer, Value numRecords,
144 bool boundsCheck, amdgpu::Chipset chipset,
145 Value cacheSwizzleStride = nullptr,
146 unsigned addressSpace = 8) {
147 // The stride value is generally 0. However, on MI-300 and onward, you can
148 // enable a cache swizzling mode by setting bit 14 of the stride field
149 // and setting that stride to a cache stride.
150 Type i16 = rewriter.getI16Type();
151 Value stride;
152 if (chipset.majorVersion == 9 && chipset >= kGfx942 && cacheSwizzleStride) {
153 Value cacheStrideZext =
154 LLVM::ZExtOp::create(rewriter, loc, i16, cacheSwizzleStride);
155 Value swizzleBit = LLVM::ConstantOp::create(
156 rewriter, loc, i16, rewriter.getI16IntegerAttr(1 << 14));
157 stride = LLVM::OrOp::create(rewriter, loc, cacheStrideZext, swizzleBit,
158 /*isDisjoint=*/true);
159 } else {
160 stride = LLVM::ConstantOp::create(rewriter, loc, i16,
161 rewriter.getI16IntegerAttr(0));
162 }
163
164 uint32_t flags = 0;
165 if (chipset >= kGfx1250) {
166 // Flag word:
167 // bit 0: swizzle
168 // bit 1: 0 means (total_offset + payload > numRecords)
169 // 1 means ((total_offset + payload >) numRecords) || ((offset +
170 // payload) > stride) only applied when swizzle_enable = 0. keep at
171 // zero.
172 // whether oob is done depends on numRecords.
173 // bits 2-3: Type (must be 0)
174 } else {
175 // Get the number of elements.
176 // Flag word:
177 // bits 0-11: dst sel, ignored by these intrinsics
178 // bits 12-14: data format (ignored, must be nonzero, 7=float)
179 // bits 15-18: data format (ignored, must be nonzero, 4=32bit)
180 // bit 19: In nested heap (0 here)
181 // bit 20: Behavior on unmap (0 means "return 0 / ignore")
182 // bits 21-22: Index stride for swizzles (N/A)
183 // bit 23: Add thread ID (0)
184 // bit 24: Reserved to 1 (RDNA) or 0 (CDNA)
185 // bits 25-26: Reserved (0)
186 // bit 27: Buffer is non-volatile (CDNA only)
187 // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 =
188 // none, 3 = either swizzles or testing against offset field) RDNA only
189 // bits 30-31: Type (must be 0)
190 flags |= (7 << 12) | (4 << 15);
191 if (chipset.majorVersion >= 10) {
192 flags |= (1 << 24);
193 uint32_t oob = boundsCheck ? 3 : 2;
194 flags |= (oob << 28);
195 }
196 }
197 Value flagsConst = createI32Constant(rewriter, loc, flags);
198 Type rsrcType =
199 LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
200 Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
201 loc, rsrcType, basePointer, stride, numRecords, flagsConst);
202 return resource;
203}
204
205namespace {
206struct FatRawBufferCastLowering
207 : public ConvertOpToLLVMPattern<FatRawBufferCastOp> {
208 FatRawBufferCastLowering(const LLVMTypeConverter &converter, Chipset chipset)
209 : ConvertOpToLLVMPattern<FatRawBufferCastOp>(converter),
210 chipset(chipset) {}
211
212 Chipset chipset;
213
214 LogicalResult
215 matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor,
216 ConversionPatternRewriter &rewriter) const override {
217 Location loc = op.getLoc();
218 Value memRef = adaptor.getSource();
219 Value unconvertedMemref = op.getSource();
220 MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());
221 MemRefDescriptor descriptor(memRef);
222
223 DataLayout dataLayout = DataLayout::closest(op);
224 int64_t elementByteWidth =
225 dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;
226
227 int64_t unusedOffset = 0;
228 SmallVector<int64_t, 5> strideVals;
229 if (failed(memrefType.getStridesAndOffset(strideVals, unusedOffset)))
230 return op.emitOpError("Can't lower non-stride-offset memrefs");
231
232 Value numRecords = adaptor.getValidBytes();
233 if (!numRecords)
234 numRecords =
235 getNumRecords(rewriter, loc, memrefType, descriptor, strideVals,
236 elementByteWidth, chipset, adaptor.getBoundsCheck());
237
238 Value basePointer =
239 adaptor.getResetOffset()
240 ? descriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
241 memrefType)
242 : descriptor.alignedPtr(rewriter, loc);
243
244 Value offset = adaptor.getResetOffset()
245 ? LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
246 rewriter.getIndexAttr(0))
247 : descriptor.offset(rewriter, loc);
248
249 bool hasSizes = memrefType.getRank() > 0;
250 // No need to unpack() and pack() all the individual sizes and strides,
251 // so we'll just extract the arrays.
252 Value sizes = hasSizes
253 ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
255 : Value{};
256 Value strides =
257 hasSizes ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
259 : Value{};
260
261 Value fatPtr = makeBufferRsrc(
262 rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(),
263 chipset, adaptor.getCacheSwizzleStride(), /*addressSpace=*/7);
264
265 Value result = MemRefDescriptor::poison(
266 rewriter, loc,
267 getTypeConverter()->convertType(op.getResult().getType()));
268 SmallVector<int64_t> pos{kAllocatedPtrPosInMemRefDescriptor};
269 result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr, pos);
270 result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr,
272 result = LLVM::InsertValueOp::create(rewriter, loc, result, offset,
274 if (hasSizes) {
275 result = LLVM::InsertValueOp::create(rewriter, loc, result, sizes,
277 result = LLVM::InsertValueOp::create(rewriter, loc, result, strides,
279 }
280 rewriter.replaceOp(op, result);
281 return success();
282 }
283};
284
285/// Define lowering patterns for raw buffer ops
286template <typename GpuOp, typename Intrinsic>
287struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
288 RawBufferOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
289 : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}
290
291 Chipset chipset;
292 static constexpr uint32_t maxVectorOpWidth = 128;
293
294 LogicalResult
295 matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor,
296 ConversionPatternRewriter &rewriter) const override {
297 Location loc = gpuOp.getLoc();
298 Value memref = adaptor.getMemref();
299 Value unconvertedMemref = gpuOp.getMemref();
300 MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());
301
302 if (chipset.majorVersion < 9)
303 return gpuOp.emitOpError("raw buffer ops require GCN or higher");
304
305 Value storeData = adaptor.getODSOperands(0)[0];
306 if (storeData == memref) // no write component to this op
307 storeData = Value();
308 Type wantedDataType;
309 if (storeData)
310 wantedDataType = storeData.getType();
311 else
312 wantedDataType = gpuOp.getODSResults(0)[0].getType();
313
314 Value atomicCmpData = Value();
315 // Operand index 1 of a load is the indices, trying to read them can crash.
316 if (storeData) {
317 Value maybeCmpData = adaptor.getODSOperands(1)[0];
318 if (maybeCmpData != memref)
319 atomicCmpData = maybeCmpData;
320 }
321
322 Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
323
324 Type i32 = rewriter.getI32Type();
325
326 // Get the type size in bytes.
327 DataLayout dataLayout = DataLayout::closest(gpuOp);
328 int64_t elementByteWidth =
329 dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;
330 Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);
331
332 // If we want to load a vector<NxT> with total size <= 32
333 // bits, use a scalar load and bitcast it. Similarly, if bitsize(T) < 32
334 // and the total load size is >= 32, use a vector load of N / (bitsize(T) /
335 // 32) x i32 and bitcast. Also, the CAS intrinsic requires integer operands,
336 // so bitcast any floats to integers.
337 Type llvmBufferValType = llvmWantedDataType;
338 if (atomicCmpData) {
339 if (auto floatType = dyn_cast<FloatType>(wantedDataType))
340 llvmBufferValType = this->getTypeConverter()->convertType(
341 rewriter.getIntegerType(floatType.getWidth()));
342 }
343 if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
344 uint32_t vecLen = dataVector.getNumElements();
345 uint32_t elemBits =
346 dataLayout.getTypeSizeInBits(dataVector.getElementType());
347 uint32_t totalBits = elemBits * vecLen;
348 bool usePackedFp16 =
349 isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
350 if (totalBits > maxVectorOpWidth)
351 return gpuOp.emitOpError(
352 "Total width of loads or stores must be no more than " +
353 Twine(maxVectorOpWidth) + " bits, but we call for " +
354 Twine(totalBits) +
355 " bits. This should've been caught in validation");
356 if (!usePackedFp16 && elemBits < 32) {
357 if (totalBits > 32) {
358 if (totalBits % 32 != 0)
359 return gpuOp.emitOpError("Load or store of more than 32-bits that "
360 "doesn't fit into words. Can't happen\n");
361 llvmBufferValType = this->typeConverter->convertType(
362 VectorType::get(totalBits / 32, i32));
363 } else {
364 llvmBufferValType = this->typeConverter->convertType(
365 rewriter.getIntegerType(totalBits));
366 }
367 }
368 }
369 if (auto vecType = dyn_cast<VectorType>(llvmBufferValType)) {
370 // Buffer intrinsics doesn't support 1-element vectors, cast them to
371 // scalars.
372 if (vecType.getNumElements() == 1)
373 llvmBufferValType = vecType.getElementType();
374 }
375
376 SmallVector<Value, 6> args;
377 if (storeData) {
378 if (llvmBufferValType != llvmWantedDataType) {
379 Value castForStore = LLVM::BitcastOp::create(
380 rewriter, loc, llvmBufferValType, storeData);
381 args.push_back(castForStore);
382 } else {
383 args.push_back(storeData);
384 }
385 }
386
387 if (atomicCmpData) {
388 if (llvmBufferValType != llvmWantedDataType) {
389 Value castForCmp = LLVM::BitcastOp::create(
390 rewriter, loc, llvmBufferValType, atomicCmpData);
391 args.push_back(castForCmp);
392 } else {
393 args.push_back(atomicCmpData);
394 }
395 }
396
397 // Construct buffer descriptor from memref, attributes
398 int64_t offset = 0;
399 SmallVector<int64_t, 5> strides;
400 if (failed(memrefType.getStridesAndOffset(strides, offset)))
401 return gpuOp.emitOpError("Can't lower non-stride-offset memrefs");
402
403 MemRefDescriptor memrefDescriptor(memref);
404
405 Value ptr = memrefDescriptor.bufferPtr(
406 rewriter, loc, *this->getTypeConverter(), memrefType);
407 Value numRecords =
408 getNumRecords(rewriter, loc, memrefType, memrefDescriptor, strides,
409 elementByteWidth, chipset, adaptor.getBoundsCheck());
410 Value resource = makeBufferRsrc(rewriter, loc, ptr, numRecords,
411 adaptor.getBoundsCheck(), chipset);
412 args.push_back(resource);
413
414 // Indexing (voffset)
415 Value voffset = getLinearIndexI32(rewriter, loc, memrefDescriptor,
416 adaptor.getIndices(), strides);
417 if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
418 indexOffset && *indexOffset > 0) {
419 Value extraOffsetConst = createI32Constant(rewriter, loc, *indexOffset);
420 voffset = voffset ? LLVM::AddOp::create(rewriter, loc, voffset,
421 extraOffsetConst)
422 : extraOffsetConst;
423 }
424 voffset = LLVM::MulOp::create(rewriter, loc, voffset, byteWidthConst);
425 args.push_back(voffset);
426
427 // SGPR offset.
428 Value sgprOffset = adaptor.getSgprOffset();
429 if (!sgprOffset)
430 sgprOffset = createI32Constant(rewriter, loc, 0);
431 sgprOffset = LLVM::MulOp::create(rewriter, loc, sgprOffset, byteWidthConst);
432 args.push_back(sgprOffset);
433
434 // bit 0: GLC = 0 (atomics drop value, less coherency)
435 // bits 1-2: SLC, DLC = 0 (similarly)
436 // bit 3: swizzled (0 for raw)
437 args.push_back(createI32Constant(rewriter, loc, 0));
438
439 llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
440 llvmBufferValType);
441 Operation *lowered = Intrinsic::create(rewriter, loc, resultTypes, args,
442 ArrayRef<NamedAttribute>());
443 if (lowered->getNumResults() == 1) {
444 Value replacement = lowered->getResult(0);
445 if (llvmBufferValType != llvmWantedDataType) {
446 replacement = LLVM::BitcastOp::create(rewriter, loc, llvmWantedDataType,
448 }
449 rewriter.replaceOp(gpuOp, replacement);
450 } else {
451 rewriter.eraseOp(gpuOp);
452 }
453 return success();
454 }
455};
456
457// TODO: AMDGPU backend already have all this bitpacking logic, we should move
458// it to some common place.
459/// Vmcnt, Expcnt and Lgkmcnt are decoded as follows:
460/// Vmcnt = Waitcnt[3:0] (pre-gfx9)
461/// Vmcnt = Waitcnt[15:14,3:0] (gfx9,10)
462/// Vmcnt = Waitcnt[15:10] (gfx11)
463/// Expcnt = Waitcnt[6:4] (pre-gfx11)
464/// Expcnt = Waitcnt[2:0] (gfx11)
465/// Lgkmcnt = Waitcnt[11:8] (pre-gfx10)
466/// Lgkmcnt = Waitcnt[13:8] (gfx10)
467/// Lgkmcnt = Waitcnt[9:4] (gfx11)
468static FailureOr<unsigned> encodeWaitcnt(Chipset chipset, unsigned vmcnt,
469 unsigned expcnt, unsigned lgkmcnt) {
470 if (chipset.majorVersion < 9) {
471 vmcnt = std::min(15u, vmcnt);
472 expcnt = std::min(7u, expcnt);
473 lgkmcnt = std::min(15u, lgkmcnt);
474 return vmcnt | (expcnt << 4) | (lgkmcnt << 8);
475 }
476 if (chipset.majorVersion == 9) {
477 vmcnt = std::min(63u, vmcnt);
478 expcnt = std::min(7u, expcnt);
479 lgkmcnt = std::min(15u, lgkmcnt);
480 unsigned lowBits = vmcnt & 0xF;
481 unsigned highBits = (vmcnt >> 4) << 14;
482 unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
483 return lowBits | highBits | otherCnts;
484 }
485 if (chipset.majorVersion == 10) {
486 vmcnt = std::min(63u, vmcnt);
487 expcnt = std::min(7u, expcnt);
488 lgkmcnt = std::min(63u, lgkmcnt);
489 unsigned lowBits = vmcnt & 0xF;
490 unsigned highBits = (vmcnt >> 4) << 14;
491 unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
492 return lowBits | highBits | otherCnts;
493 }
494 if (chipset.majorVersion == 11) {
495 vmcnt = std::min(63u, vmcnt);
496 expcnt = std::min(7u, expcnt);
497 lgkmcnt = std::min(63u, lgkmcnt);
498 return (vmcnt << 10) | expcnt | (lgkmcnt << 4);
499 }
500 return failure();
501}
502
503struct MemoryCounterWaitOpLowering
504 : public ConvertOpToLLVMPattern<MemoryCounterWaitOp> {
505 MemoryCounterWaitOpLowering(const LLVMTypeConverter &converter,
506 Chipset chipset)
507 : ConvertOpToLLVMPattern<MemoryCounterWaitOp>(converter),
508 chipset(chipset) {}
509
510 Chipset chipset;
511
512 LogicalResult
513 matchAndRewrite(MemoryCounterWaitOp op, OpAdaptor adaptor,
514 ConversionPatternRewriter &rewriter) const override {
515 if (chipset.majorVersion >= 12) {
516 Location loc = op.getLoc();
517 if (std::optional<int> ds = adaptor.getDs())
518 ROCDL::WaitDscntOp::create(rewriter, loc, *ds);
519
520 if (std::optional<int> load = adaptor.getLoad())
521 ROCDL::WaitLoadcntOp::create(rewriter, loc, *load);
522
523 if (std::optional<int> store = adaptor.getStore())
524 ROCDL::WaitStorecntOp::create(rewriter, loc, *store);
525
526 if (std::optional<int> exp = adaptor.getExp())
527 ROCDL::WaitExpcntOp::create(rewriter, loc, *exp);
528
529 if (std::optional<int> tensor = adaptor.getTensor())
530 ROCDL::WaitTensorcntOp::create(rewriter, loc, *tensor);
531
532 rewriter.eraseOp(op);
533 return success();
534 }
535
536 if (adaptor.getTensor())
537 return op.emitOpError("unsupported chipset");
538
539 auto getVal = [](Attribute attr) -> unsigned {
540 if (attr)
541 return cast<IntegerAttr>(attr).getInt();
542
543 // This value will be clamped to the maximum value for the chipset.
544 return 1024;
545 };
546 unsigned ds = getVal(adaptor.getDsAttr());
547 unsigned exp = getVal(adaptor.getExpAttr());
548
549 unsigned vmcnt = 1024;
550 Attribute load = adaptor.getLoadAttr();
551 Attribute store = adaptor.getStoreAttr();
552 if (load && store) {
553 vmcnt = getVal(load) + getVal(store);
554 } else if (load) {
555 vmcnt = getVal(load);
556 } else if (store) {
557 vmcnt = getVal(store);
558 }
559
560 FailureOr<unsigned> waitcnt = encodeWaitcnt(chipset, vmcnt, exp, ds);
561 if (failed(waitcnt))
562 return op.emitOpError("unsupported chipset");
563
564 rewriter.replaceOpWithNewOp<ROCDL::SWaitcntOp>(op, *waitcnt);
565 return success();
566 }
567};
568
569struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
570 LDSBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
571 : ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {}
572
573 Chipset chipset;
574
575 LogicalResult
576 matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
577 ConversionPatternRewriter &rewriter) const override {
578 Location loc = op.getLoc();
579 // This ensures that waits on global memory aren't introduced on
580 // chips that don't have the BackOffBarrier feature enabled in LLVM.
581 bool requiresInlineAsm = chipset < kGfx90a;
582
583 Attribute mmra =
584 rewriter.getAttr<LLVM::MMRATagAttr>("amdgpu-synchronize-as", "local");
585 // Note: while there *is* a workgroup-one-as scope, this, when combined with
586 // the MMRA, will lead to the fence having no effect. This is because the
587 // codepaths for an atomic load or store will observe that a
588 // one-address-space atomic to LDS requires no synchronization because
589 // operations on LDS are totally ordered with respect to each other, and so
590 // will not emit the correct waitcnt operations that these fences are
591 // intended to produce. Therefore, we use a broader type of fence and rely
592 // on the MMRA to relax it to the semantics we want.
593 StringRef scope = "workgroup";
594
595 auto relFence = LLVM::FenceOp::create(rewriter, loc,
596 LLVM::AtomicOrdering::release, scope);
597 relFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra);
598 if (requiresInlineAsm) {
599 auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
600 LLVM::AsmDialect::AD_ATT);
601 const char *asmStr = ";;;WARNING: BREAKS DEBUG WATCHES\ns_barrier";
602 const char *constraints = "";
603 LLVM::InlineAsmOp::create(
604 rewriter, loc,
605 /*resultTypes=*/TypeRange(), /*operands=*/ValueRange(),
606 /*asm_string=*/asmStr, constraints, /*has_side_effects=*/true,
607 /*is_align_stack=*/false, LLVM::TailCallKind::None,
608 /*asm_dialect=*/asmDialectAttr,
609 /*operand_attrs=*/ArrayAttr());
610 } else if (chipset.majorVersion < 12) {
611 ROCDL::SBarrierOp::create(rewriter, loc);
612 } else {
613 ROCDL::BarrierSignalOp::create(rewriter, loc, -1);
614 ROCDL::BarrierWaitOp::create(rewriter, loc, -1);
615 }
616
617 auto acqFence = LLVM::FenceOp::create(rewriter, loc,
618 LLVM::AtomicOrdering::acquire, scope);
619 acqFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra);
620 rewriter.replaceOp(op, acqFence);
621 return success();
622 }
623};
624
625struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
626 SchedBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
627 : ConvertOpToLLVMPattern<SchedBarrierOp>(converter), chipset(chipset) {}
628
629 Chipset chipset;
630
631 LogicalResult
632 matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor,
633 ConversionPatternRewriter &rewriter) const override {
634 rewriter.replaceOpWithNewOp<ROCDL::SchedBarrier>(op,
635 (uint32_t)op.getOpts());
636 return success();
637 }
638};
639
640} // namespace
641
642/// Pack small float vector operands (fp4/fp6/fp8/bf16) into the format
643/// expected by scaled matrix multiply intrinsics (MFMA/WMMA).
644///
645/// Specifically:
646/// 1. If the element type is bfloat16, bitcast it to i16 unless rocdl intrinsic
647/// allows bf16. Newer MFMAs support bf16 types on operand, check
648/// IntrinsicsAMDGPU.td file for reference.
649/// 2. If instead we have a more than 64-bit quantity, use a <N / 4 x i32>
650/// instead, which is what the f8f6f4 intrinsics use.
651/// 3. If `input` is a vector of N <= 8 bytes, bitcast it to a (N * 8)-bit
652/// integer.
653///
654/// Note that the type of `input` has already been LLVM type converted:
655/// therefore 8-bit and smaller floats are represented as their corresponding
656/// `iN` integers.
657static Value packSmallFloatVectorOperand(ConversionPatternRewriter &rewriter,
658 Location loc, Value input,
659 bool allowBf16 = true) {
660 Type inputType = input.getType();
661 if (auto vectorType = dyn_cast<VectorType>(inputType)) {
662 if (vectorType.getElementType().isBF16() && !allowBf16)
663 return LLVM::BitcastOp::create(
664 rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
665 if (vectorType.getElementType().isInteger(8) &&
666 vectorType.getNumElements() <= 8)
667 return LLVM::BitcastOp::create(
668 rewriter, loc,
669 rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
670 if (isa<IntegerType>(vectorType.getElementType()) &&
671 vectorType.getElementTypeBitWidth() <= 8) {
672 int64_t numWords = llvm::divideCeil(
673 vectorType.getNumElements() * vectorType.getElementTypeBitWidth(),
674 32);
675 return LLVM::BitcastOp::create(
676 rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()),
677 input);
678 }
679 }
680 return input;
681}
682
683/// Converts sparse MFMA (smfmac) operands to the expected ROCDL types.
684static Value convertSparseMFMAVectorOperand(ConversionPatternRewriter &rewriter,
685 Location loc, Value input,
686 bool allowBf16 = true) {
687 Type inputType = input.getType();
688 auto vectorType = cast<VectorType>(inputType);
689 // bf16 -> i16 when not allowed (pre-gfx950).
690 if (vectorType.getElementType().isBF16() && !allowBf16)
691 return LLVM::BitcastOp::create(
692 rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
693 // i8/fp8 vectors -> vector<Nxi32>.
694 if (isa<IntegerType>(vectorType.getElementType()) &&
695 vectorType.getElementTypeBitWidth() <= 8) {
696 int64_t numWords = llvm::divideCeil(
697 vectorType.getNumElements() * vectorType.getElementTypeBitWidth(), 32);
698 return LLVM::BitcastOp::create(
699 rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()), input);
700 }
701 return input;
702}
703
704/// Converts the scaled MFMA/WMMA operands, `scalesA` and `scalesB`, from MLIR
705/// AMDGPU dialect convention to ROCDL and LLVM AMDGPU intrinsics convention.
706///
707/// Specifically:
708/// 1. If `input` is a i8 value, zero extend it to i32
709/// 2. If `input` is a vector of length 4 or 8 and type i8, cast it to i32
710///
711/// Note that the type of `input` has already been LLVM type converted:
712/// therefore 8-bit and smaller floats are represented as their corresponding
713/// `iN` integers.
714static Value castScaleOperand(ConversionPatternRewriter &rewriter, Location loc,
715 Value input) {
716 return TypeSwitch<Type, Value>(input.getType())
717 .Case([&](IntegerType) {
718 // Handle scalar i8: zero extend to i32.
719 return LLVM::ZExtOp::create(rewriter, loc, rewriter.getI32Type(),
720 input);
721 })
722 .Case([&](VectorType vectorType) {
723 // Handle vector<4xi8> -> i32 or vector<8xi8> -> i64.
724 int64_t numElements = vectorType.getNumElements();
725 assert((numElements == 4 || numElements == 8) &&
726 "scale operand must be a vector of length 4 or 8");
727 IntegerType outputType =
728 (numElements == 4) ? rewriter.getI32Type() : rewriter.getI64Type();
729 return LLVM::BitcastOp::create(rewriter, loc, outputType, input);
730 })
731 .DefaultUnreachable("unexpected input type for scale operand");
732}
733
734/// Maps f8 scale element types to WMMA scale format codes.
735static std::optional<uint32_t> getWmmaScaleFormat(Type elemType) {
737 .Case([](Float8E8M0FNUType) { return 0; })
738 .Case([](Float8E4M3FNType) { return 2; })
739 .Default(std::nullopt);
740}
741
742/// Determines the ROCDL intrinsic name for scaled WMMA based on dimensions
743/// and scale block size (16 or 32).
744static std::optional<StringRef>
746 if (m == 16 && n == 16 && k == 128)
747 return isScale16
748 ? ROCDL::wmma_scale16_f32_16x16x128_f8f6f4::getOperationName()
749 : ROCDL::wmma_scale_f32_16x16x128_f8f6f4::getOperationName();
750
751 if (m == 32 && n == 16 && k == 128)
752 return isScale16 ? ROCDL::wmma_scale16_f32_32x16x128_f4::getOperationName()
753 : ROCDL::wmma_scale_f32_32x16x128_f4::getOperationName();
754
755 return std::nullopt;
756}
757
758/// Push an input operand. If it is a float type, nothing to do. If it is
759/// an integer type, then we need to also push its signdness (1 for signed, 0
760/// for unsigned) and we need to pack the input 16xi8 vector into a 4xi32
761/// vector (or the 8xi8 vector into a 2xi32 one for gfx12+).
762/// We also need to convert bfloat inputs to i16 to account for the bfloat
763/// intrinsics having been defined before the AMD backend supported bfloat. We
764/// similarly need to pack 8-bit float types into integers as if they were i8
765/// (which they are for the backend's purposes).
767 ConversionPatternRewriter &rewriter, Location loc,
768 const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput,
769 Value mlirInput, SmallVectorImpl<Value> &operands,
770 SmallVectorImpl<NamedAttribute> &attrs, StringRef attrName) {
771 Type inputType = llvmInput.getType();
772 auto vectorType = dyn_cast<VectorType>(inputType);
773 if (!vectorType) {
774 operands.push_back(llvmInput);
775 return;
776 }
777 Type elemType = vectorType.getElementType();
778 if (elemType.getIntOrFloatBitWidth() > 8) {
779 operands.push_back(llvmInput);
780 return;
781 }
782
783 // We need to check the type of the input before conversion to properly test
784 // for int8. This is because, in LLVM, fp8 type is converted to int8, so the
785 // fp8/int8 information is lost during the conversion process.
786 auto mlirInputType = cast<VectorType>(mlirInput.getType());
787 bool isInputInteger = mlirInputType.getElementType().isInteger();
788 if (isInputInteger) {
789 // if element type is 8-bit signed or unsigned, ignore the isUnsigned flag
790 bool localIsUnsigned = isUnsigned;
791 if (elemType.isUnsignedInteger()) {
792 localIsUnsigned = true;
793 } else if (elemType.isSignedInteger()) {
794 localIsUnsigned = false;
795 }
796 attrs.push_back(
797 NamedAttribute(attrName, rewriter.getBoolAttr(!localIsUnsigned)));
798 }
799
800 int64_t numBits =
801 vectorType.getNumElements() * elemType.getIntOrFloatBitWidth();
802 Type i32 = rewriter.getI32Type();
803 Type intrinsicInType = numBits <= 32
804 ? (Type)rewriter.getIntegerType(numBits)
805 : (Type)VectorType::get(numBits / 32, i32);
806 auto llvmIntrinsicInType = typeConverter->convertType(intrinsicInType);
807 Value castInput = rewriter.createOrFold<LLVM::BitcastOp>(
808 loc, llvmIntrinsicInType, llvmInput);
809 // The wave64-mode 16x16x16 intrinsics that take 4-bit integers only need
810 // (256 / 64) * 4 = 16 bits of input (on gfx12+) but take i32 arguments.
811 // Add in the zeros here.
812 if (numBits < 32)
813 castInput = LLVM::ZExtOp::create(rewriter, loc, i32, castInput);
814 operands.push_back(castInput);
815}
816
817/// Push the output operand. For many cases this is only pushing the output in
818/// the operand list. But when we have f16 -> f16 or bf16 -> bf16 intrinsics,
819/// since the same numbers of VGPRs is used, we need to decide if to store the
820/// result in the upper 16 bits of the VGPRs or in the lower part. To store the
821/// result in the lower 16 bits, set subwordOffset to 1, otherwise result will
822/// be stored it in the upper part. The subwordOffset must not be set for gfx12,
823/// as the instructions have been changed to return fewer registers instead.
824static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
825 Location loc,
826 const TypeConverter *typeConverter,
827 Value output, int32_t subwordOffset,
828 bool clamp, SmallVectorImpl<Value> &operands,
830 Type inputType = output.getType();
831 auto vectorType = dyn_cast<VectorType>(inputType);
832 Type elemType = vectorType.getElementType();
833 operands.push_back(output);
834 if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) {
835 attrs.push_back(
836 NamedAttribute("opsel", rewriter.getBoolAttr(subwordOffset)));
837 } else if (elemType.isInteger(32)) {
838 attrs.push_back(NamedAttribute("clamp", rewriter.getBoolAttr(clamp)));
839 }
840}
841
842/// Return true if `type` is the E5M2 variant of an 8-bit float that is
843/// supported by the `_bf8` instructions on the given `chipset`.
844static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type) {
845 return (chipset == kGfx942 && isa<Float8E5M2FNUZType>(type)) ||
846 (hasOcpFp8(chipset) && isa<Float8E5M2Type>(type));
847}
848
849/// Return true if `type` is the E4M3FN variant of an 8-bit float that is
850/// supported by the `_fp8` instructions on the given `chipset`.
851static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type) {
852 return (chipset == kGfx942 && isa<Float8E4M3FNUZType>(type)) ||
853 (hasOcpFp8(chipset) && isa<Float8E4M3FNType>(type));
854}
855
/// Return the `rocdl` intrinsic corresponding to a MFMA operation `mfma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
                                                  Chipset chipset) {
  uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
           b = mfma.getBlocks();
  Type sourceElem = getElementTypeOrSelf(mfma.getSourceA().getType());
  Type destElem = getElementTypeOrSelf(mfma.getDestC().getType());

  // f32 x f32 -> f32. The xf32 ("reduced precision") variants exist on
  // gfx942+ only and are selected via the op's reducePrecision flag.
  if (sourceElem.isF32() && destElem.isF32()) {
    if (mfma.getReducePrecision() && chipset >= kGfx942) {
      if (m == 32 && n == 32 && k == 4 && b == 1)
        return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
      if (m == 16 && n == 16 && k == 8 && b == 1)
        return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
    }
    if (m == 32 && n == 32 && k == 1 && b == 2)
      return ROCDL::mfma_f32_32x32x1f32::getOperationName();
    if (m == 16 && n == 16 && k == 1 && b == 4)
      return ROCDL::mfma_f32_16x16x1f32::getOperationName();
    if (m == 4 && n == 4 && k == 1 && b == 16)
      return ROCDL::mfma_f32_4x4x1f32::getOperationName();
    if (m == 32 && n == 32 && k == 2 && b == 1)
      return ROCDL::mfma_f32_32x32x2f32::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f32_16x16x4f32::getOperationName();
  }

  // f16 x f16 -> f32. gfx950 adds larger-K variants.
  if (sourceElem.isF16() && destElem.isF32()) {
    if (chipset >= kGfx950) {
      if (m == 32 && n == 32 && k == 16 && b == 1)
        return ROCDL::mfma_f32_32x32x16_f16::getOperationName();
      if (m == 16 && n == 16 && k == 32 && b == 1)
        return ROCDL::mfma_f32_16x16x32_f16::getOperationName();
    }
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_f32_32x32x4f16::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_f32_16x16x4f16::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_f32_4x4x4f16::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_f32_32x32x8f16::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_f32_16x16x16f16::getOperationName();
  }

  // bf16 x bf16 -> f32. gfx90a+ uses the "_1k" variants; older gfx908 uses
  // the short-K originals listed at the end.
  if (sourceElem.isBF16() && destElem.isF32()) {
    if (chipset >= kGfx950) {
      if (m == 32 && n == 32 && k == 16 && b == 1)
        return ROCDL::mfma_f32_32x32x16_bf16::getOperationName();
      if (m == 16 && n == 16 && k == 32 && b == 1)
        return ROCDL::mfma_f32_16x16x32_bf16::getOperationName();
    }
    if (chipset >= kGfx90a) {
      if (m == 32 && n == 32 && k == 4 && b == 2)
        return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
      if (m == 16 && n == 16 && k == 4 && b == 4)
        return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
      if (m == 4 && n == 4 && k == 4 && b == 16)
        return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
      if (m == 32 && n == 32 && k == 8 && b == 1)
        return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
      if (m == 16 && n == 16 && k == 16 && b == 1)
        return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
    }
    if (m == 32 && n == 32 && k == 2 && b == 2)
      return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
    if (m == 16 && n == 16 && k == 2 && b == 4)
      return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
    if (m == 4 && n == 4 && k == 2 && b == 16)
      return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
    if (m == 32 && n == 32 && k == 4 && b == 1)
      return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
    if (m == 16 && n == 16 && k == 8 && b == 1)
      return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
  }

  // i8 x i8 -> i32.
  if (sourceElem.isInteger(8) && destElem.isInteger(32)) {
    if (chipset >= kGfx950) {
      if (m == 32 && n == 32 && k == 32 && b == 1)
        return ROCDL::mfma_i32_32x32x32_i8::getOperationName();
      if (m == 16 && n == 16 && k == 64 && b == 1)
        return ROCDL::mfma_i32_16x16x64_i8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_i32_32x32x4i8::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_i32_16x16x4i8::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_i32_4x4x4i8::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_i32_32x32x8i8::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_i32_16x16x16i8::getOperationName();
    if (m == 32 && n == 32 && k == 16 && b == 1 && chipset >= kGfx942)
      return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
    if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >= kGfx942)
      return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
  }

  // f64 x f64 -> f64, gfx90a+ only.
  if (sourceElem.isF64() && destElem.isF64() && chipset >= kGfx90a) {
    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f64_16x16x4f64::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 4)
      return ROCDL::mfma_f64_4x4x4f64::getOperationName();
  }

  // bf8 (E5M2-family) A operand -> f32. The B operand may independently be
  // bf8 or fp8, giving the mixed _bf8_fp8 variants.
  if (destElem.isF32() && typeIsExpectedBf8ForChipset(chipset, sourceElem)) {
    // Known to be correct because there are no scalar f8 instructions and
    // because a length mismatch will have been caught by the verifier.
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
    }
  }

  // fp8 (E4M3-family) A operand -> f32, again with either f8 flavor for B.
  if (destElem.isF32() && typeIsExpectedFp8ForChipset(chipset, sourceElem)) {
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
    }
  }

  // No intrinsic matches this (element types, M, N, K, blocks, chipset) tuple.
  return std::nullopt;
}
1003
1004static std::optional<uint32_t> smallFloatTypeToFormatCode(Type mlirElemType) {
1006 .Case([](Float8E4M3FNType) { return 0u; })
1007 .Case([](Float8E5M2Type) { return 1u; })
1008 .Case([](Float6E2M3FNType) { return 2u; })
1009 .Case([](Float6E3M2FNType) { return 3u; })
1010 .Case([](Float4E2M1FNType) { return 4u; })
1011 .Default(std::nullopt);
1012}
1013
1014/// If there is a scaled MFMA instruction for the input element types `aType`
1015/// and `bType`, output type `destType`, problem size M, N, K, and B (number of
1016/// blocks) on the given `chipset`, return a tuple consisting of the
1017/// OperationName of the intrinsic and the type codes that need to be passed to
1018/// that intrinsic. Note that this is also used to implement some un-scaled
1019/// MFMAs, since the compiler represents the ordinary instruction as a "scaled"
1020/// MFMA with a scale of 0.
1021static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1022mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m,
1023 uint32_t n, uint32_t k, uint32_t b, Chipset chipset) {
1024 aType = getElementTypeOrSelf(aType);
1025 bType = getElementTypeOrSelf(bType);
1026 destType = getElementTypeOrSelf(destType);
1027
1028 if (chipset < kGfx950)
1029 return std::nullopt;
1030 if (!isa<Float32Type>(destType))
1031 return std::nullopt;
1032
1033 std::optional<uint32_t> aTypeCode = smallFloatTypeToFormatCode(aType);
1034 std::optional<uint32_t> bTypeCode = smallFloatTypeToFormatCode(bType);
1035 if (!aTypeCode || !bTypeCode)
1036 return std::nullopt;
1037
1038 if (m == 32 && n == 32 && k == 64 && b == 1)
1039 return std::tuple{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(),
1040 *aTypeCode, *bTypeCode};
1041 if (m == 16 && n == 16 && k == 128 && b == 1)
1042 return std::tuple{
1043 ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), *aTypeCode,
1044 *bTypeCode};
1045
1046 return std::nullopt;
1047}
1048
1049static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1050mfmaOpToScaledIntrinsic(MFMAOp mfma, Chipset chipset) {
1052 mfma.getSourceA().getType(), mfma.getSourceB().getType(),
1053 mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(),
1054 mfma.getBlocks(), chipset);
1055}
1056
1057static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1058mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
1059 return mfmaOpToScaledIntrinsic(smfma.getSourceA().getType(),
1060 smfma.getSourceB().getType(),
1061 smfma.getDestC().getType(), smfma.getM(),
1062 smfma.getN(), smfma.getK(), 1u, chipset);
1063}
1064
/// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
/// for RDNA3/4 architectures.
static std::optional<StringRef>
wmmaOpToIntrinsicRDNA(Type elemSourceType, Type elemBSourceType,
                      Type elemDestType, uint32_t k, bool isRDNA3) {
  using fp8 = Float8E4M3FNType;
  using bf8 = Float8E5M2Type;

  // Handle k == 16 for RDNA3/4.
  if (k == 16) {
    // Common patterns for RDNA3 and RDNA4.
    if (elemSourceType.isF16() && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
    if (elemSourceType.isBF16() && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
    if (elemSourceType.isF16() && elemDestType.isF16())
      return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
    if (elemSourceType.isBF16() && elemDestType.isBF16())
      return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
    if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
      return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();

    // RDNA3 specific patterns. i4 is the only RDNA3-only k==16 case; f8
    // inputs are rejected (RDNA3 has no f8 WMMA instructions).
    if (isRDNA3) {
      if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
        return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
      return std::nullopt;
    }

    // RDNA4 specific patterns (fp8/bf8). All four A/B flavor combinations
    // accumulate into f32.
    if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
        elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
    if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
        elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
    if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
        elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
    if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
        elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
      return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();

    return std::nullopt;
  }

  // Handle k == 32 for RDNA4. Only the i4 variant has a doubled-K form.
  if (k == 32 && !isRDNA3) {
    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
      return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
  }

  return std::nullopt;
}
1121
/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
/// for the gfx1250 architecture.
static std::optional<StringRef> wmmaOpToIntrinsicGfx1250(Type elemSourceType,
                                                         Type elemBSourceType,
                                                         Type elemDestType,
                                                         uint32_t k) {
  using fp8 = Float8E4M3FNType;
  using bf8 = Float8E5M2Type;

  // k == 4: the lone f32 x f32 variant.
  if (k == 4) {
    if (elemSourceType.isF32() && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x4_f32::getOperationName();

    return std::nullopt;
  }

  // k == 32: 16-bit float inputs, accumulating into f32 or same-type.
  if (k == 32) {
    if (elemSourceType.isF16() && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x32_f16::getOperationName();
    if (elemSourceType.isBF16() && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x32_bf16::getOperationName();
    if (elemSourceType.isF16() && elemDestType.isF16())
      return ROCDL::wmma_f16_16x16x32_f16::getOperationName();
    if (elemSourceType.isBF16() && elemDestType.isBF16())
      return ROCDL::wmma_bf16_16x16x32_bf16::getOperationName();

    return std::nullopt;
  }

  // k == 64: 8-bit float (all four A/B flavor combinations, f32 or f16
  // accumulation) and i8 inputs.
  if (k == 64) {
    if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x64_fp8_fp8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x64_fp8_fp8::getOperationName();
    }
    if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x64_fp8_bf8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x64_fp8_bf8::getOperationName();
    }
    if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x64_bf8_bf8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x64_bf8_bf8::getOperationName();
    }
    if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x64_bf8_fp8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x64_bf8_fp8::getOperationName();
    }
    if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
      return ROCDL::wmma_i32_16x16x64_iu8::getOperationName();

    return std::nullopt;
  }

  // k == 128: doubled-K 8-bit float variants only.
  if (k == 128) {
    if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x128_fp8_fp8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x128_fp8_fp8::getOperationName();
    }
    if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x128_fp8_bf8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x128_fp8_bf8::getOperationName();
    }
    if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x128_bf8_bf8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x128_bf8_bf8::getOperationName();
    }
    if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x128_bf8_fp8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x128_bf8_fp8::getOperationName();
    }

    return std::nullopt;
  }

  return std::nullopt;
}
1213
/// Returns the `rocdl` intrinsic corresponding to a SparseMFMA (smfmac)
/// operation if one exists. This includes checking to ensure the intrinsic is
/// supported on the architecture you are compiling for.
static std::optional<StringRef> smfmacOpToIntrinsic(SparseMFMAOp op,
                                                    Chipset chipset) {
  bool isGfx950 = chipset >= kGfx950;
  auto isFp8 = [&](Type t) { return typeIsExpectedFp8ForChipset(chipset, t); };
  auto isBf8 = [&](Type t) { return typeIsExpectedBf8ForChipset(chipset, t); };

  uint32_t m = op.getM(), n = op.getN(), k = op.getK();
  Type sourceAElem = getElementTypeOrSelf(op.getSourceA().getType());
  Type sourceBElem = getElementTypeOrSelf(op.getSourceB().getType());
  Type destElem = getElementTypeOrSelf(op.getDestC().getType());

  // 16x16x32: 16-bit float inputs (available on all smfmac chips).
  if (m == 16 && n == 16 && k == 32) {
    if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32())
      return ROCDL::smfmac_f32_16x16x32_f16::getOperationName();
    if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32())
      return ROCDL::smfmac_f32_16x16x32_bf16::getOperationName();
  }

  // 16x16x64: f16/bf16 only on gfx950; i8 and f8 flavors on gfx942+.
  if (m == 16 && n == 16 && k == 64) {
    if (isGfx950) {
      if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32())
        return ROCDL::smfmac_f32_16x16x64_f16::getOperationName();
      if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32())
        return ROCDL::smfmac_f32_16x16x64_bf16::getOperationName();
    }
    if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
        destElem.isInteger(32))
      return ROCDL::smfmac_i32_16x16x64_i8::getOperationName();
    if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_16x16x64_fp8_fp8::getOperationName();
    if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_16x16x64_fp8_bf8::getOperationName();
    if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_16x16x64_bf8_fp8::getOperationName();
    if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_16x16x64_bf8_bf8::getOperationName();
  }

  // 16x16x128: doubled-K i8/f8 variants, gfx950 only.
  if (m == 16 && n == 16 && k == 128 && isGfx950) {
    if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
        destElem.isInteger(32))
      return ROCDL::smfmac_i32_16x16x128_i8::getOperationName();
    if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_16x16x128_fp8_fp8::getOperationName();
    if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_16x16x128_fp8_bf8::getOperationName();
    if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_16x16x128_bf8_fp8::getOperationName();
    if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_16x16x128_bf8_bf8::getOperationName();
  }

  // 32x32x16: 16-bit float inputs.
  if (m == 32 && n == 32 && k == 16) {
    if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32())
      return ROCDL::smfmac_f32_32x32x16_f16::getOperationName();
    if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32())
      return ROCDL::smfmac_f32_32x32x16_bf16::getOperationName();
  }

  // 32x32x32: f16/bf16 on gfx950; i8 and f8 flavors on gfx942+.
  if (m == 32 && n == 32 && k == 32) {
    if (isGfx950) {
      if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32())
        return ROCDL::smfmac_f32_32x32x32_f16::getOperationName();
      if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32())
        return ROCDL::smfmac_f32_32x32x32_bf16::getOperationName();
    }
    if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
        destElem.isInteger(32))
      return ROCDL::smfmac_i32_32x32x32_i8::getOperationName();
    if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_32x32x32_fp8_fp8::getOperationName();
    if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_32x32x32_fp8_bf8::getOperationName();
    if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_32x32x32_bf8_fp8::getOperationName();
    if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_32x32x32_bf8_bf8::getOperationName();
  }

  // 32x32x64: doubled-K i8/f8 variants, gfx950 only.
  if (m == 32 && n == 32 && k == 64 && isGfx950) {
    if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
        destElem.isInteger(32))
      return ROCDL::smfmac_i32_32x32x64_i8::getOperationName();
    if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_32x32x64_fp8_fp8::getOperationName();
    if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_32x32x64_fp8_bf8::getOperationName();
    if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_32x32x64_bf8_fp8::getOperationName();
    if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_32x32x64_bf8_bf8::getOperationName();
  }

  return std::nullopt;
}
1312
1313/// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
1314/// if one exists. This includes checking to ensure the intrinsic is supported
1315/// on the architecture you are compiling for.
1316static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
1317 Chipset chipset) {
1318 auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
1319 auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
1320 auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
1321 Type elemSourceType = sourceVectorType.getElementType();
1322 Type elemBSourceType = sourceBVectorType.getElementType();
1323 Type elemDestType = destVectorType.getElementType();
1324
1325 const uint32_t k = wmma.getK();
1326 const bool isRDNA3 = chipset.majorVersion == 11;
1327 const bool isRDNA4 = chipset.majorVersion == 12 && chipset.minorVersion == 0;
1328
1329 // Handle RDNA3 and RDNA4.
1330 if (isRDNA3 || isRDNA4)
1331 return wmmaOpToIntrinsicRDNA(elemSourceType, elemBSourceType, elemDestType,
1332 k, isRDNA3);
1333
1334 // Handle gfx1250.
1335 if (chipset == kGfx1250)
1336 return wmmaOpToIntrinsicGfx1250(elemSourceType, elemBSourceType,
1337 elemDestType, k);
1338
1339 return std::nullopt;
1340}
1341
1342namespace {
1343struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
1344 MFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1345 : ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {}
1346
1347 Chipset chipset;
1348
1349 LogicalResult
1350 matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
1351 ConversionPatternRewriter &rewriter) const override {
1352 Location loc = op.getLoc();
1353 Type outType = typeConverter->convertType(op.getDestD().getType());
1354 Type intrinsicOutType = outType;
1355 if (auto outVecType = dyn_cast<VectorType>(outType))
1356 if (outVecType.getElementType().isBF16())
1357 intrinsicOutType = outVecType.clone(rewriter.getI16Type());
1358
1359 if (chipset.majorVersion != 9 || chipset < kGfx908)
1360 return op->emitOpError("MFMA only supported on gfx908+");
1361 uint32_t getBlgpField = static_cast<uint32_t>(op.getBlgp());
1362 if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
1363 if (chipset < kGfx942)
1364 return op.emitOpError("negation unsupported on older than gfx942");
1365 getBlgpField |=
1366 op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
1367 }
1368 std::optional<StringRef> maybeIntrinsic = mfmaOpToIntrinsic(op, chipset);
1369 std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1370 maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
1371 if (!maybeIntrinsic.has_value() && !maybeScaledIntrinsic.has_value())
1372 return op.emitOpError("no intrinsic matching MFMA size on given chipset");
1373
1374 bool isScaled =
1375 !maybeIntrinsic.has_value() && maybeScaledIntrinsic.has_value();
1376 if (isScaled &&
1377 (adaptor.getAbid() > 0 || getBlgpField > 0 || op.getCbsz() > 0)) {
1378 return op.emitOpError(
1379 "non-default abid, blgp, and cbsz aren't supported on MFMAs that can "
1380 "be scaled as those fields are used for type information");
1381 }
1382
1383 StringRef intrinsicName =
1384 isScaled ? std::get<0>(*maybeScaledIntrinsic) : *maybeIntrinsic;
1385 // Determine if we can use bf16 in the intrinsic. Newer MFMAs in gfx950+
1386 // allows bf16 as the input. For reference check IntrinsicsAMDGPU.td file.
1387 bool allowBf16 = [&]() {
1388 if (chipset < kGfx950)
1389 return false;
1390 if (isScaled)
1391 return true;
1392 return intrinsicName.contains("16x16x32.bf16") ||
1393 intrinsicName.contains("32x32x16.bf16");
1394 }();
1395 OperationState loweredOp(loc, intrinsicName);
1396 loweredOp.addTypes(intrinsicOutType);
1397 loweredOp.addOperands({packSmallFloatVectorOperand(
1398 rewriter, loc, adaptor.getSourceA(), allowBf16),
1400 rewriter, loc, adaptor.getSourceB(), allowBf16),
1401 adaptor.getDestC()});
1402 if (isScaled) {
1403 Value zero = createI32Constant(rewriter, loc, 0);
1404 auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
1405 loweredOp.addOperands({/*scale A=*/zero, /*scale B=*/zero});
1406 loweredOp.addAttributes({{"cbsz", rewriter.getI32IntegerAttr(aTypeCode)},
1407 {"blgp", rewriter.getI32IntegerAttr(bTypeCode)},
1408 {"opselA", rewriter.getI32IntegerAttr(0)},
1409 {"opselB", rewriter.getI32IntegerAttr(0)}});
1410 } else {
1411 loweredOp.addAttributes(
1412 {{"cbsz", rewriter.getI32IntegerAttr(op.getCbsz())},
1413 {"abid", rewriter.getI32IntegerAttr(op.getAbid())},
1414 {"blgp", rewriter.getI32IntegerAttr(getBlgpField)}});
1415 };
1416 Value lowered = rewriter.create(loweredOp)->getResult(0);
1417 if (outType != intrinsicOutType)
1418 lowered = LLVM::BitcastOp::create(rewriter, loc, outType, lowered);
1419 rewriter.replaceOp(op, lowered);
1420 return success();
1421 }
1422};
1423
1424struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
1425 ScaledMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1426 : ConvertOpToLLVMPattern(converter), chipset(chipset) {}
1427
1428 Chipset chipset;
1429
1430 LogicalResult
1431 matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor,
1432 ConversionPatternRewriter &rewriter) const override {
1433 Location loc = op.getLoc();
1434 Type intrinsicOutType = typeConverter->convertType(op.getDestD().getType());
1435
1436 if (chipset.majorVersion != 9 || chipset < kGfx950)
1437 return op->emitOpError("scaled MFMA only supported on gfx908+");
1438 std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1439 maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
1440 if (!maybeScaledIntrinsic.has_value())
1441 return op.emitOpError(
1442 "no intrinsic matching scaled MFMA size on given chipset");
1443
1444 auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
1445 OperationState loweredOp(loc, intrinsicName);
1446 loweredOp.addTypes(intrinsicOutType);
1447 loweredOp.addOperands(
1448 {packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceA()),
1449 packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceB()),
1450 adaptor.getDestC()});
1451 loweredOp.addOperands(
1452 {/*scales A*/
1453 castScaleOperand(rewriter, loc, adaptor.getScalesA()),
1454 /*scales B*/
1455 castScaleOperand(rewriter, loc, adaptor.getScalesB())});
1456 loweredOp.addAttributes(
1457 {{"cbsz", rewriter.getI32IntegerAttr(aTypeCode)},
1458 {"blgp", rewriter.getI32IntegerAttr(bTypeCode)},
1459 {"opselA", rewriter.getI32IntegerAttr(adaptor.getScalesIdxA())},
1460 {"opselB", rewriter.getI32IntegerAttr(adaptor.getScalesIdxB())}});
1461
1462 Value lowered = rewriter.create(loweredOp)->getResult(0);
1463 rewriter.replaceOp(op, lowered);
1464 return success();
1465 }
1466};
1467
1468struct SparseMFMAOpLowering : public ConvertOpToLLVMPattern<SparseMFMAOp> {
1469 SparseMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1470 : ConvertOpToLLVMPattern<SparseMFMAOp>(converter), chipset(chipset) {}
1471
1472 Chipset chipset;
1473
1474 LogicalResult
1475 matchAndRewrite(SparseMFMAOp op, SparseMFMAOpAdaptor adaptor,
1476 ConversionPatternRewriter &rewriter) const override {
1477 Location loc = op.getLoc();
1478 auto outType =
1479 typeConverter->convertType<VectorType>(op.getDestC().getType());
1480 if (!outType)
1481 return rewriter.notifyMatchFailure(op, "type conversion failed");
1482
1483 // smfmac is supported on gfx942 and gfx950.
1484 if (chipset.majorVersion != 9 || chipset < kGfx942)
1485 return op->emitOpError("sparse MFMA (smfmac) only supported on gfx942+");
1486 bool isGfx950 = chipset >= kGfx950;
1487
1488 Value a = convertSparseMFMAVectorOperand(rewriter, loc,
1489 adaptor.getSourceA(), isGfx950);
1490 Value b = convertSparseMFMAVectorOperand(rewriter, loc,
1491 adaptor.getSourceB(), isGfx950);
1492 Value c = adaptor.getDestC();
1493
1494 std::optional<StringRef> maybeIntrinsic = smfmacOpToIntrinsic(op, chipset);
1495 if (!maybeIntrinsic.has_value())
1496 return op.emitOpError(
1497 "no intrinsic matching sparse MFMA on the given chipset");
1498
1499 // Bitcast sparse indices from vector<4xi8> or vector<2xi16> to i32.
1500 Value sparseIdx = LLVM::BitcastOp::create(
1501 rewriter, loc, rewriter.getI32Type(), adaptor.getSparseIdx());
1502
1503 OperationState loweredOp(loc, maybeIntrinsic.value());
1504 loweredOp.addTypes(outType);
1505 loweredOp.addOperands({a, b, c, sparseIdx});
1506 loweredOp.addAttributes(
1507 {{"cbsz", rewriter.getI32IntegerAttr(op.getCbsz())},
1508 {"abid", rewriter.getI32IntegerAttr(op.getAbid())}});
1509 Value lowered = rewriter.create(loweredOp)->getResult(0);
1510 rewriter.replaceOp(op, lowered);
1511 return success();
1512 }
1513};
1514
1515struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
1516 WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1517 : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}
1518
1519 Chipset chipset;
1520
1521 LogicalResult
1522 matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
1523 ConversionPatternRewriter &rewriter) const override {
1524 Location loc = op.getLoc();
1525 auto outType =
1526 typeConverter->convertType<VectorType>(op.getDestD().getType());
1527 if (!outType)
1528 return rewriter.notifyMatchFailure(op, "type conversion failed");
1529
1530 if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
1531 return op->emitOpError("WMMA only supported on gfx11 and gfx12");
1532
1533 bool isGFX1250 = chipset >= kGfx1250;
1534
1535 // The WMMA operations represent vectors of bf16s as vectors of i16s
1536 // (except on gfx1250), so we need to bitcast bfloats to i16 and then
1537 // bitcast them back.
1538 auto aType = cast<VectorType>(adaptor.getSourceA().getType());
1539 auto bType = cast<VectorType>(adaptor.getSourceB().getType());
1540 auto destCType = cast<VectorType>(adaptor.getDestC().getType());
1541 bool castAToI16 = aType.getElementType().isBF16() && !isGFX1250;
1542 bool castBToI16 = bType.getElementType().isBF16() && !isGFX1250;
1543 bool castDestCToI16 = destCType.getElementType().isBF16() && !isGFX1250;
1544 bool castOutToI16 = outType.getElementType().isBF16() && !isGFX1250;
1545 VectorType rawOutType = outType;
1546 if (castOutToI16)
1547 rawOutType = outType.clone(rewriter.getI16Type());
1548 Value a = adaptor.getSourceA();
1549 if (castAToI16)
1550 a = LLVM::BitcastOp::create(rewriter, loc,
1551 aType.clone(rewriter.getI16Type()), a);
1552 Value b = adaptor.getSourceB();
1553 if (castBToI16)
1554 b = LLVM::BitcastOp::create(rewriter, loc,
1555 bType.clone(rewriter.getI16Type()), b);
1556 Value destC = adaptor.getDestC();
1557 if (castDestCToI16)
1558 destC = LLVM::BitcastOp::create(
1559 rewriter, loc, destCType.clone(rewriter.getI16Type()), destC);
1560
1561 std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);
1562
1563 if (!maybeIntrinsic.has_value())
1564 return op.emitOpError("no intrinsic matching WMMA on the given chipset");
1565
1566 if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0)
1567 return op.emitOpError("subwordOffset not supported on gfx12+");
1568
1569 SmallVector<Value, 4> operands;
1570 SmallVector<NamedAttribute, 4> attrs;
1571 wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(), a,
1572 op.getSourceA(), operands, attrs, "signA");
1573 wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(), b,
1574 op.getSourceB(), operands, attrs, "signB");
1575 wmmaPushOutputOperand(rewriter, loc, typeConverter, destC,
1576 op.getSubwordOffset(), op.getClamp(), operands,
1577 attrs);
1578
1579 OperationState loweredOp(loc, *maybeIntrinsic);
1580 loweredOp.addTypes(rawOutType);
1581 loweredOp.addOperands(operands);
1582 loweredOp.addAttributes(attrs);
1583 Operation *lowered = rewriter.create(loweredOp);
1584
1585 Operation *maybeCastBack = lowered;
1586 if (rawOutType != outType)
1587 maybeCastBack = LLVM::BitcastOp::create(rewriter, loc, outType,
1588 lowered->getResult(0));
1589 rewriter.replaceOp(op, maybeCastBack->getResults());
1590
1591 return success();
1592 }
1593};
1594
/// Lowers `amdgpu.scaled_wmma` to the gfx1250+ `rocdl.wmma.scale*` intrinsics,
/// packing the small-float sources and the scale vectors into the i32-based
/// operand formats those intrinsics expect.
struct ScaledWMMAOpLowering : public ConvertOpToLLVMPattern<ScaledWMMAOp> {
  ScaledWMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<ScaledWMMAOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(ScaledWMMAOp op, ScaledWMMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto outType =
        typeConverter->convertType<VectorType>(op.getDestD().getType());
    if (!outType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    if (chipset < kGfx1250)
      return op->emitOpError("WMMA scale only supported on gfx1250+");

    // Tile dimensions select the intrinsic variant below.
    int64_t m = op.getM();
    int64_t n = op.getN();
    int64_t k = op.getK();

    Type aElemType = getElementTypeOrSelf(op.getSourceA().getType());
    Type bElemType = getElementTypeOrSelf(op.getSourceB().getType());

    // Encode the small-float source element types as the intrinsic's
    // numeric format codes.
    std::optional<uint32_t> aFmtCode = smallFloatTypeToFormatCode(aElemType);
    std::optional<uint32_t> bFmtCode = smallFloatTypeToFormatCode(bElemType);

    if (!aFmtCode || !bFmtCode)
      return op.emitOpError("unsupported element types for scaled_wmma");

    // Get scale vector types and determine variant (scale vs scale16).
    auto scaleAVecType = cast<VectorType>(op.getScaleA().getType());
    auto scaleBVecType = cast<VectorType>(op.getScaleB().getType());

    if (scaleAVecType.getNumElements() != scaleBVecType.getNumElements())
      return op.emitOpError("scaleA and scaleB must have equal vector length");

    // Extract scale format from element types.
    Type scaleAElemType = scaleAVecType.getElementType();
    Type scaleBElemType = scaleBVecType.getElementType();

    std::optional<uint32_t> scaleAFmt = getWmmaScaleFormat(scaleAElemType);
    std::optional<uint32_t> scaleBFmt = getWmmaScaleFormat(scaleBElemType);

    if (!scaleAFmt || !scaleBFmt)
      return op.emitOpError("unsupported scale element types");

    // Determine which intrinsic to use based on dimensions. 8-element scale
    // vectors use the "scale" variant; 8 elements per operand means scale16.
    bool isScale16 = (scaleAVecType.getNumElements() == 8);
    std::optional<StringRef> intrinsicName =
        getScaledWmmaIntrinsicName(m, n, k, isScale16);
    if (!intrinsicName)
      return op.emitOpError("unsupported scaled_wmma dimensions: ")
             << m << "x" << n << "x" << k;

    SmallVector<NamedAttribute, 8> attrs;

    // The f4 variant does not have fmtA and fmtB attributes.
    bool is32x16 = (m == 32 && n == 16 && k == 128);
    if (!is32x16) {
      attrs.emplace_back("fmtA", rewriter.getI32IntegerAttr(*aFmtCode));
      attrs.emplace_back("fmtB", rewriter.getI32IntegerAttr(*bFmtCode));
    }

    // modC uses default value of 0.
    attrs.emplace_back("modC", rewriter.getI16IntegerAttr(0));

    // Scale attributes. Convert user-facing firstScaleLane (0 or 16) to the
    // half of the wave that is being selected (0 or 1).
    attrs.emplace_back(
        "scaleAType", rewriter.getI32IntegerAttr(op.getAFirstScaleLane() / 16));
    attrs.emplace_back("fmtScaleA", rewriter.getI32IntegerAttr(*scaleAFmt));
    attrs.emplace_back(
        "scaleBType", rewriter.getI32IntegerAttr(op.getBFirstScaleLane() / 16));
    attrs.emplace_back("fmtScaleB", rewriter.getI32IntegerAttr(*scaleBFmt));

    // Reuse flags use default value of false.
    attrs.emplace_back("reuseA", rewriter.getBoolAttr(false));
    attrs.emplace_back("reuseB", rewriter.getBoolAttr(false));

    // Convert typed float vectors to packed format.
    Value sourceA =
        packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceA());
    Value sourceB =
        packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceB());

    // Pack scale vectors into i32/i64.
    Value packedScaleA = castScaleOperand(rewriter, loc, adaptor.getScaleA());
    Value packedScaleB = castScaleOperand(rewriter, loc, adaptor.getScaleB());

    // Create the intrinsic call.
    OperationState loweredOp(loc, *intrinsicName);
    loweredOp.addTypes(outType);
    loweredOp.addOperands(
        {sourceA, sourceB, adaptor.getDestC(), packedScaleA, packedScaleB});
    loweredOp.addAttributes(attrs);

    Operation *lowered = rewriter.create(loweredOp);
    rewriter.replaceOp(op, lowered->getResults());

    return success();
  }
};
1699
1700struct TransposeLoadOpLowering
1701 : public ConvertOpToLLVMPattern<TransposeLoadOp> {
1702 TransposeLoadOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1703 : ConvertOpToLLVMPattern<TransposeLoadOp>(converter), chipset(chipset) {}
1704
1705 Chipset chipset;
1706
1707 LogicalResult
1708 matchAndRewrite(TransposeLoadOp op, TransposeLoadOpAdaptor adaptor,
1709 ConversionPatternRewriter &rewriter) const override {
1710 if (chipset != kGfx950)
1711 return op.emitOpError("Non-gfx950 chipset not supported");
1712
1713 Location loc = op.getLoc();
1714 auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
1715
1716 // Elements in subbyte memrefs are stored non-contiguously,
1717 // reject if source is sub-byte memref. Use emulated memrefs instead.
1718 size_t srcElementSize =
1719 srcMemRefType.getElementType().getIntOrFloatBitWidth();
1720 if (srcElementSize < 8)
1721 return op.emitOpError("Expect source memref to have at least 8 bits "
1722 "element size, got ")
1723 << srcElementSize;
1724
1725 auto resultType = cast<VectorType>(op.getResult().getType());
1726 Value srcPtr =
1727 getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
1728 (adaptor.getSrcIndices()));
1729
1730 size_t numElements = resultType.getNumElements();
1731 size_t elementTypeSize =
1732 resultType.getElementType().getIntOrFloatBitWidth();
1733
1734 // ROCDL transpose load intrinsics return vectors of 32-bit integers, if
1735 // the element size is smaller than 16 bits.
1736 Type rocdlResultType = VectorType::get((numElements * elementTypeSize) / 32,
1737 rewriter.getIntegerType(32));
1738 Type llvmResultType = typeConverter->convertType(resultType);
1739
1740 switch (elementTypeSize) {
1741 case 4: {
1742 assert(numElements == 16);
1743 auto rocdlOp = ROCDL::ds_read_tr4_b64::create(rewriter, loc,
1744 rocdlResultType, srcPtr);
1745 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
1746 break;
1747 }
1748 case 6: {
1749 assert(numElements == 16);
1750 auto rocdlOp = ROCDL::ds_read_tr6_b96::create(rewriter, loc,
1751 rocdlResultType, srcPtr);
1752 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
1753 break;
1754 }
1755 case 8: {
1756 assert(numElements == 8);
1757 auto rocdlOp = ROCDL::ds_read_tr8_b64::create(rewriter, loc,
1758 rocdlResultType, srcPtr);
1759 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
1760 break;
1761 }
1762 case 16: {
1763 assert(numElements == 4);
1764 rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr16_b64>(op, llvmResultType,
1765 srcPtr);
1766 break;
1767 }
1768 default:
1769 return op.emitOpError("Unsupported element size for transpose load");
1770 }
1771 return success();
1772 }
1773};
1774
1775struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
1776 GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1777 : ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}
1778
1779 Chipset chipset;
1780
1781 LogicalResult
1782 matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor,
1783 ConversionPatternRewriter &rewriter) const override {
1784 if (chipset.majorVersion < 9 || chipset.majorVersion > 10)
1785 return op.emitOpError("pre-gfx9 and post-gfx10 not supported");
1786
1787 Location loc = op.getLoc();
1788
1789 auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
1790 auto dstMemRefType = cast<MemRefType>(op.getDst().getType());
1791
1792 // TODO: instead of only transfering one element per thread, we could
1793 // augment it to transfer multiple elements per thread by issuing multiple
1794 // `global_load_lds` instructions.
1795 Type transferType = op.getTransferType();
1796 int loadWidth = [&]() -> int {
1797 if (auto transferVectorType = dyn_cast<VectorType>(transferType)) {
1798 return (transferVectorType.getNumElements() *
1799 transferVectorType.getElementTypeBitWidth()) /
1800 8;
1801 }
1802 return transferType.getIntOrFloatBitWidth() / 8;
1803 }();
1804
1805 // Currently only 1, 2, 4, 12 and 16 byte loads are supported.
1806 if (!llvm::is_contained({1, 2, 4, 12, 16}, loadWidth))
1807 return op.emitOpError("chipset unsupported element size");
1808
1809 if (chipset != kGfx950 && llvm::is_contained({12, 16}, loadWidth))
1810 return op.emitOpError("Gather to LDS instructions with 12-byte and "
1811 "16-byte load widths are only supported on gfx950");
1812
1813 Value srcPtr =
1814 getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
1815 (adaptor.getSrcIndices()));
1816 Value dstPtr =
1817 getStridedElementPtr(rewriter, loc, dstMemRefType, adaptor.getDst(),
1818 (adaptor.getDstIndices()));
1819
1820 if (op.getAsync()) {
1821 rewriter.replaceOpWithNewOp<ROCDL::LoadAsyncToLDSOp>(
1822 op, srcPtr, dstPtr, rewriter.getI32IntegerAttr(loadWidth),
1823 /*offset=*/rewriter.getI32IntegerAttr(0),
1824 /*aux=*/rewriter.getI32IntegerAttr(0), ArrayAttr{}, ArrayAttr{},
1825 ArrayAttr{});
1826 } else {
1827 rewriter.replaceOpWithNewOp<ROCDL::LoadToLDSOp>(
1828 op, srcPtr, dstPtr, rewriter.getI32IntegerAttr(loadWidth),
1829 /*offset=*/rewriter.getI32IntegerAttr(0),
1830 /*aux=*/rewriter.getI32IntegerAttr(0), ArrayAttr{}, ArrayAttr{},
1831 ArrayAttr{});
1832 }
1833
1834 return success();
1835 }
1836};
1837
namespace {
/// Lowers `amdgpu.ext_packed_fp8` to ROCDL fp8->f32 conversion intrinsics
/// (definition of matchAndRewrite follows below).
struct ExtPackedFp8OpLowering final
    : public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
  ExtPackedFp8OpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::ExtPackedFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Lowers `amdgpu.scaled_ext_packed_matrix` to the gfx1250
/// `rocdl.cvt.scale.pk*` intrinsics (definition follows below).
struct ScaledExtPackedMatrixOpLowering final
    : public ConvertOpToLLVMPattern<ScaledExtPackedMatrixOp> {
  ScaledExtPackedMatrixOpLowering(const LLVMTypeConverter &converter,
                                  Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedMatrixOp>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(ScaledExtPackedMatrixOp op,
                  ScaledExtPackedMatrixOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Lowers `amdgpu.packed_trunc_2xfp8` to the ROCDL packed f32->fp8
/// truncation intrinsics (definition follows below).
struct PackedTrunc2xFp8OpLowering final
    : public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> {
  PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter,
                             Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::PackedTrunc2xFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Lowers `amdgpu.packed_stoch_round_fp8` to the ROCDL stochastic-rounding
/// fp8 truncation intrinsics (definition follows below).
struct PackedStochRoundFp8OpLowering final
    : public ConvertOpToLLVMPattern<PackedStochRoundFp8Op> {
  PackedStochRoundFp8OpLowering(const LLVMTypeConverter &converter,
                                Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::PackedStochRoundFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PackedStochRoundFp8Op op,
                  PackedStochRoundFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Lowers `amdgpu.scaled_ext_packed` to the gfx950 scaled conversion
/// intrinsics (definition follows below).
struct ScaledExtPackedOpLowering final
    : public ConvertOpToLLVMPattern<ScaledExtPackedOp> {
  ScaledExtPackedOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedOp>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Lowers `amdgpu.packed_scaled_trunc` to the gfx950 scaled truncation
/// intrinsics (definition follows below).
struct PackedScaledTruncOpLowering final
    : public ConvertOpToLLVMPattern<PackedScaledTruncOp> {
  PackedScaledTruncOpLowering(const LLVMTypeConverter &converter,
                              Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::PackedScaledTruncOp>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

} // end namespace
1918
1919LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
1920 ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
1921 ConversionPatternRewriter &rewriter) const {
1922 Location loc = op.getLoc();
1923 if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
1924 return rewriter.notifyMatchFailure(
1925 loc, "Fp8 conversion instructions are not available on target "
1926 "architecture and their emulation is not implemented");
1927 Type v4i8 =
1928 getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type()));
1929 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
1930 Type f32 = getTypeConverter()->convertType(op.getResult().getType());
1931
1932 Value source = adaptor.getSource();
1933 auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
1934 auto resultVecType = dyn_cast<VectorType>(op.getResult().getType());
1935 Type sourceElemType = getElementTypeOrSelf(op.getSource());
1936 // Extend to a v4i8
1937 if (!sourceVecType || sourceVecType.getNumElements() < 4) {
1938 Value longVec = LLVM::UndefOp::create(rewriter, loc, v4i8);
1939 if (!sourceVecType) {
1940 longVec = LLVM::InsertElementOp::create(
1941 rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0));
1942 } else {
1943 for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
1944 Value idx = createI32Constant(rewriter, loc, i);
1945 Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
1946 longVec =
1947 LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
1948 }
1949 }
1950 source = longVec;
1951 }
1952 Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
1953 if (resultVecType) {
1954 if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
1955 rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source,
1956 op.getIndex());
1957 } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
1958 rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Fp8Op>(op, f32, i32Source,
1959 op.getIndex());
1960 }
1961 } else {
1962 if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
1963 rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
1964 op.getIndex());
1965 } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
1966 rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
1967 op.getIndex());
1968 }
1969 }
1970 return success();
1971}
1972
1973int32_t getScaleSel(int32_t blockSize, unsigned bitWidth, int32_t scaleWaveHalf,
1974 int32_t firstScaleByte) {
1975 // When lowering amdgpu.scaled_ext_packed_matrix to rocdl.cvt.scale.pk*.f*.f*
1976 // operations, the attributes blockSize, sourceType, scaleWaveHalf, and
1977 // firstScaleByte are merged into a single attribute scaleSel. This is how
1978 // those values are merged together. (Note: scaleWaveHalf isn't a high-level
1979 // attribute but is derifed from firstScaleLane).
1980 assert(llvm::is_contained({16, 32}, blockSize));
1981 assert(llvm::is_contained({4u, 6u, 8u}, bitWidth));
1982
1983 const bool isFp8 = bitWidth == 8;
1984 const bool isBlock16 = blockSize == 16;
1985
1986 if (!isFp8) {
1987 int32_t bit0 = isBlock16;
1988 assert(llvm::is_contained({0, 1, 2}, firstScaleByte));
1989 int32_t bit1 = (firstScaleByte == 2) << 1;
1990 assert(llvm::is_contained({0, 1}, scaleWaveHalf));
1991 int32_t bit2 = scaleWaveHalf << 2;
1992 return bit2 | bit1 | bit0;
1993 }
1994
1995 int32_t bit0 = isBlock16;
1996 // firstScaleByte is guaranteed to be defined by two bits.
1997 assert(llvm::is_contained({0, 1, 2, 3}, firstScaleByte));
1998 int32_t bits2and1 = firstScaleByte << 1;
1999 assert(llvm::is_contained({0, 1}, scaleWaveHalf));
2000 int32_t bit3 = scaleWaveHalf << 3;
2001 int32_t bits = bit3 | bits2and1 | bit0;
2002 // These are invalid cases.
2003 assert(!llvm::is_contained(
2004 {0b0011, 0b0101, 0b0111, 0b1000, 0b1001, 0b1011, 0b1111}, bits));
2005 return bits;
2006}
2007
2008static std::optional<StringRef>
2009scaledExtPacked816ToIntrinsic(Type srcElemType, Type destElemType) {
2010 using fp4 = Float4E2M1FNType;
2011 using fp8 = Float8E4M3FNType;
2012 using bf8 = Float8E5M2Type;
2013 using fp6 = Float6E2M3FNType;
2014 using bf6 = Float6E3M2FNType;
2015 if (isa<fp4>(srcElemType)) {
2016 if (destElemType.isF16())
2017 return ROCDL::CvtPkScalePk8F16Fp4Op::getOperationName();
2018 if (destElemType.isBF16())
2019 return ROCDL::CvtPkScalePk8Bf16Fp4Op::getOperationName();
2020 if (destElemType.isF32())
2021 return ROCDL::CvtPkScalePk8F32Fp4Op::getOperationName();
2022 return std::nullopt;
2023 }
2024 if (isa<fp8>(srcElemType)) {
2025 if (destElemType.isF16())
2026 return ROCDL::CvtPkScalePk8F16Fp8Op::getOperationName();
2027 if (destElemType.isBF16())
2028 return ROCDL::CvtPkScalePk8Bf16Fp8Op::getOperationName();
2029 if (destElemType.isF32())
2030 return ROCDL::CvtPkScalePk8F32Fp8Op::getOperationName();
2031 return std::nullopt;
2032 }
2033 if (isa<bf8>(srcElemType)) {
2034 if (destElemType.isF16())
2035 return ROCDL::CvtPkScalePk8F16Bf8Op::getOperationName();
2036 if (destElemType.isBF16())
2037 return ROCDL::CvtPkScalePk8Bf16Bf8Op::getOperationName();
2038 if (destElemType.isF32())
2039 return ROCDL::CvtPkScalePk8F32Bf8Op::getOperationName();
2040 return std::nullopt;
2041 }
2042 if (isa<fp6>(srcElemType)) {
2043 if (destElemType.isF16())
2044 return ROCDL::CvtPkScalePk16F16Fp6Op::getOperationName();
2045 if (destElemType.isBF16())
2046 return ROCDL::CvtPkScalePk16Bf16Fp6Op::getOperationName();
2047 if (destElemType.isF32())
2048 return ROCDL::CvtPkScalePk16F32Fp6Op::getOperationName();
2049 return std::nullopt;
2050 }
2051 if (isa<bf6>(srcElemType)) {
2052 if (destElemType.isF16())
2053 return ROCDL::CvtPkScalePk16F16Bf6Op::getOperationName();
2054 if (destElemType.isBF16())
2055 return ROCDL::CvtPkScalePk16Bf16Bf6Op::getOperationName();
2056 if (destElemType.isF32())
2057 return ROCDL::CvtPkScalePk16F32Bf6Op::getOperationName();
2058 return std::nullopt;
2059 }
2060 llvm_unreachable("invalid combination of element types for packed conversion "
2061 "instructions");
2062}
2063
2064LogicalResult ScaledExtPackedMatrixOpLowering::matchAndRewrite(
2065 ScaledExtPackedMatrixOp op, ScaledExtPackedMatrixOpAdaptor adaptor,
2066 ConversionPatternRewriter &rewriter) const {
2067 using fp4 = Float4E2M1FNType;
2068 using fp8 = Float8E4M3FNType;
2069 using bf8 = Float8E5M2Type;
2070 using fp6 = Float6E2M3FNType;
2071 using bf6 = Float6E3M2FNType;
2072 Location loc = op.getLoc();
2073 if (chipset != kGfx1250) {
2074 return rewriter.notifyMatchFailure(
2075 loc,
2076 "Scaled fp packed conversion instructions are not available on target "
2077 "architecture and their emulation is not implemented");
2078 }
2079 // Convert user-facing firstScaleLane (0 or 16) to the half of the wave that
2080 // is being selected.
2081 int32_t scaleWaveHalf = op.getFirstScaleLane() / 16;
2082 int32_t firstScaleByte = op.getFirstScaleByte();
2083 int32_t blockSize = op.getBlockSize();
2084 auto sourceType = cast<VectorType>(op.getSource().getType());
2085 auto srcElemType = cast<FloatType>(sourceType.getElementType());
2086 unsigned bitWidth = srcElemType.getWidth();
2087
2088 auto targetType = cast<VectorType>(op.getResult().getType());
2089 auto destElemType = cast<FloatType>(targetType.getElementType());
2090
2091 IntegerType i32 = rewriter.getI32Type();
2092 Value source = adaptor.getSource();
2093 Type llvmResultType = typeConverter->convertType(op.getResult().getType());
2094 Type packedType = nullptr;
2095 if (isa<fp4>(srcElemType)) {
2096 packedType = i32;
2097 packedType = getTypeConverter()->convertType(packedType);
2098 } else if (isa<fp8, bf8>(srcElemType)) {
2099 packedType = VectorType::get(2, i32);
2100 packedType = getTypeConverter()->convertType(packedType);
2101 } else if (isa<fp6, bf6>(srcElemType)) {
2102 packedType = VectorType::get(3, i32);
2103 packedType = getTypeConverter()->convertType(packedType);
2104 } else {
2105 llvm_unreachable("invalid element type for packed scaled ext");
2106 }
2107
2108 if (!packedType || !llvmResultType) {
2109 return rewriter.notifyMatchFailure(op, "type conversion failed");
2110 }
2111
2112 std::optional<StringRef> maybeIntrinsic =
2113 scaledExtPacked816ToIntrinsic(srcElemType, destElemType);
2114 if (!maybeIntrinsic.has_value())
2115 return op.emitOpError(
2116 "no intrinsic matching packed scaled conversion on the given chipset");
2117
2118 int32_t scaleSel =
2119 getScaleSel(blockSize, bitWidth, scaleWaveHalf, firstScaleByte);
2120 Value castedScale =
2121 LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getScale());
2122 Value castedSource =
2123 LLVM::BitcastOp::create(rewriter, loc, packedType, source);
2124
2125 OperationState loweredOp(loc, *maybeIntrinsic);
2126 loweredOp.addTypes({llvmResultType});
2127 loweredOp.addOperands({castedSource, castedScale});
2128
2129 SmallVector<NamedAttribute, 1> attrs;
2130 attrs.push_back(
2131 NamedAttribute("scaleSel", rewriter.getI32IntegerAttr(scaleSel)));
2132
2133 loweredOp.addAttributes(attrs);
2134 Operation *lowered = rewriter.create(loweredOp);
2135 rewriter.replaceOp(op, lowered);
2136
2137 return success();
2138}
2139
2140LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
2141 ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
2142 ConversionPatternRewriter &rewriter) const {
2143 Location loc = op.getLoc();
2144 if (chipset != kGfx950)
2145 return rewriter.notifyMatchFailure(
2146 loc, "Scaled fp conversion instructions are not available on target "
2147 "architecture and their emulation is not implemented");
2148 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2149
2150 Value source = adaptor.getSource();
2151 Value scale = adaptor.getScale();
2152
2153 VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
2154 Type sourceElemType = sourceVecType.getElementType();
2155 VectorType destVecType = cast<VectorType>(op.getResult().getType());
2156 Type destElemType = destVecType.getElementType();
2157
2158 VectorType packedVecType;
2159 if (isa<Float8E5M2Type, Float8E4M3FNType>(sourceElemType)) {
2160 VectorType v4i8 = VectorType::get(4, rewriter.getI8Type());
2161 packedVecType = cast<VectorType>(getTypeConverter()->convertType(v4i8));
2162 } else if (isa<Float4E2M1FNType>(sourceElemType)) {
2163 VectorType v8i4 = VectorType::get(8, rewriter.getI4Type());
2164 packedVecType = cast<VectorType>(getTypeConverter()->convertType(v8i4));
2165 } else {
2166 llvm_unreachable("invalid element type for scaled ext");
2167 }
2168
2169 // Extend to a packedVectorType
2170 if (sourceVecType.getNumElements() < packedVecType.getNumElements()) {
2171 Value longVec = LLVM::ZeroOp::create(rewriter, loc, packedVecType);
2172 if (!sourceVecType) {
2173 longVec = LLVM::InsertElementOp::create(
2174 rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0));
2175 } else {
2176 for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
2177 Value idx = createI32Constant(rewriter, loc, i);
2178 Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
2179 longVec =
2180 LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
2181 }
2182 }
2183 source = longVec;
2184 }
2185 Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
2186
2187 if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF32())
2188 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Bf8Op>(
2189 op, destVecType, i32Source, scale, op.getIndex());
2190 else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF16())
2191 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Bf8Op>(
2192 op, destVecType, i32Source, scale, op.getIndex());
2193 else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isBF16())
2194 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Bf8Op>(
2195 op, destVecType, i32Source, scale, op.getIndex());
2196 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF32())
2197 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp8Op>(
2198 op, destVecType, i32Source, scale, op.getIndex());
2199 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF16())
2200 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp8Op>(
2201 op, destVecType, i32Source, scale, op.getIndex());
2202 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isBF16())
2203 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp8Op>(
2204 op, destVecType, i32Source, scale, op.getIndex());
2205 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF32())
2206 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp4Op>(
2207 op, destVecType, i32Source, scale, op.getIndex());
2208 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF16())
2209 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp4Op>(
2210 op, destVecType, i32Source, scale, op.getIndex());
2211 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isBF16())
2212 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp4Op>(
2213 op, destVecType, i32Source, scale, op.getIndex());
2214 else
2215 return failure();
2216
2217 return success();
2218}
2219
/// Lowers `amdgpu.packed_scaled_trunc` to the gfx950 `cvt.scalef32.pk.*`
/// truncation intrinsics. The intrinsics write into one slice of an existing
/// packed word (i32 for fp4 results, v2i16 for fp8 results), so an `existing`
/// word is materialized (or zero-initialized) first, and the result is
/// bitcast back to the op's declared result type at the end.
LogicalResult PackedScaledTruncOpLowering::matchAndRewrite(
    PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (chipset != kGfx950)
    return rewriter.notifyMatchFailure(
        loc, "Scaled fp conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type v2i16 = getTypeConverter()->convertType(
      VectorType::get(2, rewriter.getI16Type()));
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Type resultType = op.getResult().getType();
  Type resultElemType = getElementTypeOrSelf(resultType);
  VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
  Type sourceElemType = sourceVecType.getElementType();

  // fp4 results pack into a single i32 word; fp8 results into v2i16.
  Type intResultType = isa<Float4E2M1FNType>(resultElemType) ? i32 : v2i16;

  Value source = adaptor.getSource();
  Value scale = adaptor.getScale();
  // The intrinsic updates `existing` in place; start from zeros when the
  // caller did not provide a word to merge into.
  Value existing = adaptor.getExisting();
  if (existing)
    existing = LLVM::BitcastOp::create(rewriter, loc, intResultType, existing);
  else
    existing = LLVM::ZeroOp::create(rewriter, loc, intResultType);

  // The intrinsics always convert a pair of elements; zero-pad a
  // single-element source to a 2-vector.
  if (sourceVecType.getNumElements() < 2) {
    Value c0 = createI32Constant(rewriter, loc, 0);
    Value elem0 = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
    VectorType v2 = VectorType::get(2, sourceElemType);
    source = LLVM::ZeroOp::create(rewriter, loc, v2);
    source = LLVM::InsertElementOp::create(rewriter, loc, source, elem0, c0);
  }

  // The f32-source intrinsics take the two lanes as separate scalar
  // operands; the f16/bf16 forms take the vector directly.
  Value sourceA, sourceB;
  if (sourceElemType.isF32()) {
    Value c0 = createI32Constant(rewriter, loc, 0);
    Value c1 = createI32Constant(rewriter, loc, 1);
    sourceA = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
    sourceB = LLVM::ExtractElementOp::create(rewriter, loc, source, c1);
  }

  // Dispatch on the (source, result) element-type pair.
  Value result;
  if (sourceElemType.isF32() && isa<Float8E5M2Type>(resultElemType))
    result = ROCDL::CvtScaleF32PkBf8F32Op::create(rewriter, loc, intResultType,
                                                  existing, sourceA, sourceB,
                                                  scale, op.getIndex());
  else if (sourceElemType.isF16() && isa<Float8E5M2Type>(resultElemType))
    result = ROCDL::CvtScaleF32PkBf8F16Op::create(
        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  else if (sourceElemType.isBF16() && isa<Float8E5M2Type>(resultElemType))
    result = ROCDL::CvtScaleF32PkBf8Bf16Op::create(
        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  else if (sourceElemType.isF32() && isa<Float8E4M3FNType>(resultElemType))
    result = ROCDL::CvtScaleF32PkFp8F32Op::create(rewriter, loc, intResultType,
                                                  existing, sourceA, sourceB,
                                                  scale, op.getIndex());
  else if (sourceElemType.isF16() && isa<Float8E4M3FNType>(resultElemType))
    result = ROCDL::CvtScaleF32PkFp8F16Op::create(
        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  else if (sourceElemType.isBF16() && isa<Float8E4M3FNType>(resultElemType))
    result = ROCDL::CvtScaleF32PkFp8Bf16Op::create(
        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  else if (sourceElemType.isF32() && isa<Float4E2M1FNType>(resultElemType))
    result = ROCDL::CvtScaleF32PkFp4F32Op::create(rewriter, loc, intResultType,
                                                  existing, sourceA, sourceB,
                                                  scale, op.getIndex());
  else if (sourceElemType.isF16() && isa<Float4E2M1FNType>(resultElemType))
    result = ROCDL::CvtScaleF32PkFp4F16Op::create(
        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  else if (sourceElemType.isBF16() && isa<Float4E2M1FNType>(resultElemType))
    result = ROCDL::CvtScaleF32PkFp4Bf16Op::create(
        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  else
    return failure();

  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
  return success();
}
2301
2302LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
2303 PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
2304 ConversionPatternRewriter &rewriter) const {
2305 Location loc = op.getLoc();
2306 if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
2307 return rewriter.notifyMatchFailure(
2308 loc, "Fp8 conversion instructions are not available on target "
2309 "architecture and their emulation is not implemented");
2310 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2311
2312 Type resultType = op.getResult().getType();
2313 Type resultElemType = getElementTypeOrSelf(resultType);
2314
2315 Value sourceA = adaptor.getSourceA();
2316 Value sourceB = adaptor.getSourceB();
2317 if (!sourceB)
2318 sourceB = LLVM::UndefOp::create(rewriter, loc, sourceA.getType());
2319 Value existing = adaptor.getExisting();
2320 if (existing)
2321 existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
2322 else
2323 existing = LLVM::UndefOp::create(rewriter, loc, i32);
2324
2325 Value result;
2326 if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
2327 result = ROCDL::CvtPkBf8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
2328 existing, op.getWordIndex());
2329 else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
2330 result = ROCDL::CvtPkFp8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
2331 existing, op.getWordIndex());
2332
2333 result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
2334 op, getTypeConverter()->convertType(resultType), result);
2335 return success();
2336}
2337
2338LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
2339 PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
2340 ConversionPatternRewriter &rewriter) const {
2341 Location loc = op.getLoc();
2342 if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
2343 return rewriter.notifyMatchFailure(
2344 loc, "Fp8 conversion instructions are not available on target "
2345 "architecture and their emulation is not implemented");
2346 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2347
2348 Type resultType = op.getResult().getType();
2349 Type resultElemType = getElementTypeOrSelf(resultType);
2350
2351 Value source = adaptor.getSource();
2352 Value stoch = adaptor.getStochiasticParam();
2353 Value existing = adaptor.getExisting();
2354 if (existing)
2355 existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
2356 else
2357 existing = LLVM::UndefOp::create(rewriter, loc, i32);
2358
2359 Value result;
2360 if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
2361 result = ROCDL::CvtSrBf8F32Op::create(rewriter, loc, i32, source, stoch,
2362 existing, op.getStoreIndex());
2363 else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
2364 result = ROCDL::CvtSrFp8F32Op::create(rewriter, loc, i32, source, stoch,
2365 existing, op.getStoreIndex());
2366
2367 result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
2368 op, getTypeConverter()->convertType(resultType), result);
2369 return success();
2370}
2371
2372// Implement the AMDGPU_DPPLowering class that will convert the amdgpu.dpp
2373// operation into the corresponding ROCDL instructions.
2374struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
2375 AMDGPUDPPLowering(const LLVMTypeConverter &converter, Chipset chipset)
2376 : ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
2377 Chipset chipset;
2378
2379 LogicalResult
2380 matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
2381 ConversionPatternRewriter &rewriter) const override {
2382
2383 // Convert the source operand to the corresponding LLVM type
2384 Location loc = DppOp.getLoc();
2385 Value src = adaptor.getSrc();
2386 Value old = adaptor.getOld();
2387 Type srcType = src.getType();
2388 Type oldType = old.getType();
2389 Type llvmType = nullptr;
2390 if (srcType.getIntOrFloatBitWidth() < 32) {
2391 llvmType = rewriter.getI32Type();
2392 } else if (isa<FloatType>(srcType)) {
2393 llvmType = (srcType.getIntOrFloatBitWidth() == 32)
2394 ? rewriter.getF32Type()
2395 : rewriter.getF64Type();
2396 } else if (isa<IntegerType>(srcType)) {
2397 llvmType = (srcType.getIntOrFloatBitWidth() == 32)
2398 ? rewriter.getI32Type()
2399 : rewriter.getI64Type();
2400 }
2401 auto llvmSrcIntType = typeConverter->convertType(
2402 rewriter.getIntegerType(srcType.getIntOrFloatBitWidth()));
2403
2404 // If the source type is less of 32, use bitcast to convert it to i32.
2405 auto convertOperand = [&](Value operand, Type operandType) {
2406 if (operandType.getIntOrFloatBitWidth() <= 16) {
2407 if (llvm::isa<FloatType>(operandType)) {
2408 operand =
2409 LLVM::BitcastOp::create(rewriter, loc, llvmSrcIntType, operand);
2410 }
2411 auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
2412 32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
2413 Value undefVec = LLVM::UndefOp::create(rewriter, loc, llvmVecType);
2414 operand =
2415 LLVM::InsertElementOp::create(rewriter, loc, undefVec, operand,
2416 createI32Constant(rewriter, loc, 0));
2417 operand = LLVM::BitcastOp::create(rewriter, loc, llvmType, operand);
2418 }
2419 return operand;
2420 };
2421
2422 src = convertOperand(src, srcType);
2423 old = convertOperand(old, oldType);
2424
2425 // This is taken from the following file llvm/lib/Target/AMDGPU/SIDefines.h
2426 enum DppCtrl : unsigned {
2427 ROW_SHL0 = 0x100,
2428 ROW_SHR0 = 0x110,
2429 ROW_ROR0 = 0x120,
2430 WAVE_SHL1 = 0x130,
2431 WAVE_ROL1 = 0x134,
2432 WAVE_SHR1 = 0x138,
2433 WAVE_ROR1 = 0x13C,
2434 ROW_MIRROR = 0x140,
2435 ROW_HALF_MIRROR = 0x141,
2436 BCAST15 = 0x142,
2437 BCAST31 = 0x143,
2438 };
2439
2440 auto kind = DppOp.getKind();
2441 auto permArgument = DppOp.getPermArgument();
2442 uint32_t DppCtrl = 0;
2443
2444 switch (kind) {
2445
2446 case DPPPerm::quad_perm: {
2447 auto quadPermAttr = cast<ArrayAttr>(*permArgument);
2448 int32_t i = 0;
2449 for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
2450 uint32_t num = elem.getInt();
2451 DppCtrl |= num << (i * 2);
2452 i++;
2453 }
2454 break;
2455 }
2456 case DPPPerm::row_shl: {
2457 auto intAttr = cast<IntegerAttr>(*permArgument);
2458 DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
2459 break;
2460 }
2461 case DPPPerm::row_shr: {
2462 auto intAttr = cast<IntegerAttr>(*permArgument);
2463 DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
2464 break;
2465 }
2466 case DPPPerm::row_ror: {
2467 auto intAttr = cast<IntegerAttr>(*permArgument);
2468 DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
2469 break;
2470 }
2471 case DPPPerm::wave_shl:
2472 DppCtrl = DppCtrl::WAVE_SHL1;
2473 break;
2474 case DPPPerm::wave_shr:
2475 DppCtrl = DppCtrl::WAVE_SHR1;
2476 break;
2477 case DPPPerm::wave_rol:
2478 DppCtrl = DppCtrl::WAVE_ROL1;
2479 break;
2480 case DPPPerm::wave_ror:
2481 DppCtrl = DppCtrl::WAVE_ROR1;
2482 break;
2483 case DPPPerm::row_mirror:
2484 DppCtrl = DppCtrl::ROW_MIRROR;
2485 break;
2486 case DPPPerm::row_half_mirror:
2487 DppCtrl = DppCtrl::ROW_HALF_MIRROR;
2488 break;
2489 case DPPPerm::row_bcast_15:
2490 DppCtrl = DppCtrl::BCAST15;
2491 break;
2492 case DPPPerm::row_bcast_31:
2493 DppCtrl = DppCtrl::BCAST31;
2494 break;
2495 }
2496
2497 // Check for row_mask, bank_mask, bound_ctrl if they exist and create
2498 // constants
2499 auto rowMask = DppOp->getAttrOfType<IntegerAttr>("row_mask").getInt();
2500 auto bankMask = DppOp->getAttrOfType<IntegerAttr>("bank_mask").getInt();
2501 bool boundCtrl = DppOp->getAttrOfType<BoolAttr>("bound_ctrl").getValue();
2502
2503 // create a ROCDL_DPPMovOp instruction with the appropriate attributes
2504 auto dppMovOp =
2505 ROCDL::DPPUpdateOp::create(rewriter, loc, llvmType, old, src, DppCtrl,
2506 rowMask, bankMask, boundCtrl);
2507
2508 Value result = dppMovOp.getRes();
2509 if (srcType.getIntOrFloatBitWidth() < 32) {
2510 result = LLVM::TruncOp::create(rewriter, loc, llvmSrcIntType, result);
2511 if (!llvm::isa<IntegerType>(srcType)) {
2512 result = LLVM::BitcastOp::create(rewriter, loc, srcType, result);
2513 }
2514 }
2515
2516 // We are replacing the AMDGPU_DPPOp instruction with the new
2517 // ROCDL_DPPMovOp instruction
2518 rewriter.replaceOp(DppOp, ValueRange(result));
2519 return success();
2520 }
2521};
2522
2523struct AMDGPUSwizzleBitModeLowering
2524 : public ConvertOpToLLVMPattern<SwizzleBitModeOp> {
2526
2527 LogicalResult
2528 matchAndRewrite(SwizzleBitModeOp op, OpAdaptor adaptor,
2529 ConversionPatternRewriter &rewriter) const override {
2530 Location loc = op.getLoc();
2531 Type i32 = rewriter.getI32Type();
2532 Value src = adaptor.getSrc();
2533 SmallVector<Value> decomposed =
2534 LLVM::decomposeValue(rewriter, loc, src, i32);
2535 unsigned andMask = op.getAndMask();
2536 unsigned orMask = op.getOrMask();
2537 unsigned xorMask = op.getXorMask();
2538
2539 // bit 15 is 0 for the BitMode swizzle.
2540 // https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/
2541 unsigned mask = andMask | (orMask << 5) | (xorMask << 10);
2542 Value maskValue = createI32Constant(rewriter, loc, mask);
2543 SmallVector<Value> swizzled;
2544 for (Value v : decomposed) {
2545 Value res =
2546 ROCDL::DsSwizzleOp::create(rewriter, loc, v.getType(), v, maskValue);
2547 swizzled.emplace_back(res);
2548 }
2549
2550 Value result = LLVM::composeValue(rewriter, loc, swizzled, src.getType());
2551 rewriter.replaceOp(op, result);
2552 return success();
2553 }
2554};
2555
2556struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneSwapOp> {
2558
2559 AMDGPUPermlaneLowering(const LLVMTypeConverter &converter, Chipset chipset)
2560 : ConvertOpToLLVMPattern<PermlaneSwapOp>(converter), chipset(chipset) {}
2561 Chipset chipset;
2562
2563 LogicalResult
2564 matchAndRewrite(PermlaneSwapOp op, OpAdaptor adaptor,
2565 ConversionPatternRewriter &rewriter) const override {
2566 if (chipset < kGfx950)
2567 return op->emitOpError("permlane_swap is only supported on gfx950+");
2568
2569 Location loc = op.getLoc();
2570 Type i32 = rewriter.getI32Type();
2571 Value src = adaptor.getSrc();
2572 unsigned rowLength = op.getRowLength();
2573 bool fi = op.getFetchInactive();
2574 bool boundctrl = op.getBoundCtrl();
2575
2576 SmallVector<Value> decomposed =
2577 LLVM::decomposeValue(rewriter, loc, src, i32);
2578
2579 SmallVector<Value> permuted;
2580 for (Value v : decomposed) {
2581 Value res;
2582 Type i32pair = LLVM::LLVMStructType::getLiteral(
2583 rewriter.getContext(), {v.getType(), v.getType()});
2584
2585 if (rowLength == 16)
2586 res = ROCDL::Permlane16SwapOp::create(rewriter, loc, i32pair, v, v, fi,
2587 boundctrl);
2588 else if (rowLength == 32)
2589 res = ROCDL::Permlane32SwapOp::create(rewriter, loc, i32pair, v, v, fi,
2590 boundctrl);
2591 else
2592 llvm_unreachable("unsupported row length");
2593
2594 Value vdst0 = LLVM::ExtractValueOp::create(rewriter, loc, res, {0});
2595 Value vdst1 = LLVM::ExtractValueOp::create(rewriter, loc, res, {1});
2596
2597 Value isEqual = LLVM::ICmpOp::create(rewriter, loc,
2598 LLVM::ICmpPredicate::eq, vdst0, v);
2599
2600 // Per `permlane(16|32)` semantics: if the first extracted element equals
2601 // 'v', the result is the second element; otherwise it is the first.
2602 Value vdstNew =
2603 LLVM::SelectOp::create(rewriter, loc, isEqual, vdst1, vdst0);
2604 permuted.emplace_back(vdstNew);
2605 }
2606
2607 Value result = LLVM::composeValue(rewriter, loc, permuted, src.getType());
2608 rewriter.replaceOp(op, result);
2609 return success();
2610 }
2611};
2612
2613//===----------------------------------------------------------------------===//
2614// In-LDS Barrier Operations
2615//===----------------------------------------------------------------------===//
2616
// Bit layout of ds_barrier_state (as i64):
// [63:32] init count (32 bits)
// [31:29] phase (3 bits)
// [28:0] pending count (29 bits)
constexpr int32_t kDsBarrierPendingCountBitWidth = 29;
// The phase field sits immediately above the pending count.
constexpr int32_t kDsBarrierPhasePos = kDsBarrierPendingCountBitWidth;
// Bit offset of the init count within the 64-bit state word.
constexpr int32_t kDsBarrierInitCountPos = 32;
// Low-bit mask selecting the pending-count field.
constexpr int32_t kDsBarrierPendingCountMask =
    (1 << kDsBarrierPendingCountBitWidth) - 1;
2626
2627struct DsBarrierInitOpLowering
2628 : public ConvertOpToLLVMPattern<DsBarrierInitOp> {
2629 Chipset chipset;
2630
2631 DsBarrierInitOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
2632 : ConvertOpToLLVMPattern<DsBarrierInitOp>(converter), chipset(chipset) {}
2633
2634 LogicalResult
2635 matchAndRewrite(DsBarrierInitOp op, OpAdaptor adaptor,
2636 ConversionPatternRewriter &rewriter) const override {
2637 if (chipset < kGfx1250)
2638 return op->emitOpError("only supported on gfx1250+");
2639
2640 Location loc = op.getLoc();
2641 Type i64 = rewriter.getI64Type();
2642
2643 MemRefType memrefType = cast<MemRefType>(op.getBase().getType());
2644 Value ptr = getStridedElementPtr(rewriter, loc, memrefType,
2645 adaptor.getBase(), adaptor.getIndices());
2646
2647 // Note: We give participants as the number of arrivals that have to occur
2648 // before the phase changes. Hardware changes the phase when updating the
2649 // pending count would underflow, so we subtract 1 to get the behavior we're
2650 // looking for.
2651 Value initCount =
2652 LLVM::SubOp::create(rewriter, loc, adaptor.getParticipants(),
2653 createI32Constant(rewriter, loc, 1));
2654
2655 // Just a bit of paranoia, but this also allows for configurable width if
2656 // that becomes a thing.
2657 Value countMask =
2658 createI32Constant(rewriter, loc, kDsBarrierPendingCountMask);
2659 Value maskedCount32 =
2660 LLVM::AndOp::create(rewriter, loc, initCount, countMask);
2661 Value maskedCount = LLVM::ZExtOp::create(rewriter, loc, i64, maskedCount32);
2662
2663 Value initCountShifted = LLVM::ShlOp::create(
2664 rewriter, loc, maskedCount,
2665 createI64Constant(rewriter, loc, kDsBarrierInitCountPos));
2666 Value barrierState =
2667 LLVM::OrOp::create(rewriter, loc, initCountShifted, maskedCount);
2668
2669 LLVM::StoreOp::create(
2670 rewriter, loc, barrierState, ptr, /*alignment=*/8, /*isVolatile=*/false,
2671 /*isNonTemporal=*/false,
2672 /*isInvariantGroup=*/false, LLVM::AtomicOrdering::release,
2673 /*syncscope=*/"workgroup");
2674
2675 rewriter.eraseOp(op);
2676 return success();
2677 }
2678};
2679
2680struct DsBarrierPollStateOpLowering
2681 : public ConvertOpToLLVMPattern<DsBarrierPollStateOp> {
2682 Chipset chipset;
2683
2684 DsBarrierPollStateOpLowering(const LLVMTypeConverter &converter,
2685 Chipset chipset)
2686 : ConvertOpToLLVMPattern<DsBarrierPollStateOp>(converter),
2687 chipset(chipset) {}
2688
2689 LogicalResult
2690 matchAndRewrite(DsBarrierPollStateOp op, OpAdaptor adaptor,
2691 ConversionPatternRewriter &rewriter) const override {
2692 if (chipset < kGfx1250)
2693 return op->emitOpError("only supported on gfx1250+");
2694
2695 Location loc = op.getLoc();
2696 Type i64 = rewriter.getI64Type();
2697
2698 MemRefType memrefType = cast<MemRefType>(op.getBase().getType());
2699 Value ptr = getStridedElementPtr(rewriter, loc, memrefType,
2700 adaptor.getBase(), adaptor.getIndices());
2701
2702 // Atomic load with workgroup scope and acquire ordering should be what
2703 // we're looking for.
2704 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
2705 op, i64, ptr, /*alignment=*/8, /*volatile_=*/false,
2706 /*nontemporal=*/false, /*invariant=*/false,
2707 /*invariantGroup=*/false, LLVM::AtomicOrdering::acquire,
2708 /*syncscope=*/"workgroup");
2709 return success();
2710 }
2711};
2712
2713struct DsAsyncBarrierArriveOpLowering
2714 : public ConvertOpToLLVMPattern<DsAsyncBarrierArriveOp> {
2715 Chipset chipset;
2716
2717 DsAsyncBarrierArriveOpLowering(const LLVMTypeConverter &converter,
2718 Chipset chipset)
2719 : ConvertOpToLLVMPattern<DsAsyncBarrierArriveOp>(converter),
2720 chipset(chipset) {}
2721
2722 LogicalResult
2723 matchAndRewrite(DsAsyncBarrierArriveOp op, OpAdaptor adaptor,
2724 ConversionPatternRewriter &rewriter) const override {
2725 if (chipset < kGfx1250)
2726 return op->emitOpError("only supported on gfx1250+");
2727
2728 Location loc = op.getLoc();
2729
2730 MemRefType memrefType = cast<MemRefType>(op.getBase().getType());
2731 Value ptr = getStridedElementPtr(rewriter, loc, memrefType,
2732 adaptor.getBase(), adaptor.getIndices());
2733
2734 rewriter.replaceOpWithNewOp<ROCDL::DsAtomicAsyncBarrierArriveOp>(
2735 op, ptr, /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr,
2736 /*tbaa=*/nullptr);
2737 return success();
2738 }
2739};
2740
2741struct DsBarrierArriveOpLowering
2742 : public ConvertOpToLLVMPattern<DsBarrierArriveOp> {
2743 Chipset chipset;
2744
2745 DsBarrierArriveOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
2746 : ConvertOpToLLVMPattern<DsBarrierArriveOp>(converter), chipset(chipset) {
2747 }
2748
2749 LogicalResult
2750 matchAndRewrite(DsBarrierArriveOp op, OpAdaptor adaptor,
2751 ConversionPatternRewriter &rewriter) const override {
2752 if (chipset < kGfx1250)
2753 return op->emitOpError("only supported on gfx1250+");
2754
2755 Location loc = op.getLoc();
2756 Type i64 = rewriter.getI64Type();
2757
2758 MemRefType memrefType = cast<MemRefType>(op.getBase().getType());
2759 Value ptr = getStridedElementPtr(rewriter, loc, memrefType,
2760 adaptor.getBase(), adaptor.getIndices());
2761
2762 rewriter.replaceOpWithNewOp<ROCDL::DsAtomicBarrierArriveRtnOp>(
2763 op, i64, ptr, adaptor.getCount(), /*alias_scopes=*/nullptr,
2764 /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
2765 return success();
2766 }
2767};
2768
2769struct DsBarrierStatePhaseOpLowering
2770 : public ConvertOpToLLVMPattern<DsBarrierStatePhaseOp> {
2772
2773 LogicalResult
2774 matchAndRewrite(DsBarrierStatePhaseOp op, OpAdaptor adaptor,
2775 ConversionPatternRewriter &rewriter) const override {
2776 Location loc = op.getLoc();
2777 Type i32 = rewriter.getI32Type();
2778
2779 Value state = adaptor.getState();
2780
2781 Value noInitCount = LLVM::TruncOp::create(rewriter, loc, i32, state);
2782 Value phase = LLVM::LShrOp::create(
2783 rewriter, loc, noInitCount,
2784 createI32Constant(rewriter, loc, kDsBarrierPhasePos));
2785
2786 rewriter.replaceOp(op, phase);
2787 return success();
2788 }
2789};
2790
2791struct DsBarrierStatePendingCountOpLowering
2792 : public ConvertOpToLLVMPattern<DsBarrierStatePendingCountOp> {
2794
2795 LogicalResult
2796 matchAndRewrite(DsBarrierStatePendingCountOp op, OpAdaptor adaptor,
2797 ConversionPatternRewriter &rewriter) const override {
2798 Location loc = op.getLoc();
2799 Type i32 = rewriter.getI32Type();
2800
2801 Value state = adaptor.getState();
2802
2803 Value noInitCount = LLVM::TruncOp::create(rewriter, loc, i32, state);
2804 Value pendingCount = LLVM::AndOp::create(
2805 rewriter, loc, noInitCount,
2806 createI32Constant(rewriter, loc,
2807 static_cast<uint32_t>(kDsBarrierPendingCountMask)));
2808
2809 rewriter.replaceOp(op, pendingCount);
2810 return success();
2811 }
2812};
2813
2814struct DsBarrierStateInitCountOpLowering
2815 : public ConvertOpToLLVMPattern<DsBarrierStateInitCountOp> {
2817
2818 LogicalResult
2819 matchAndRewrite(DsBarrierStateInitCountOp op, OpAdaptor adaptor,
2820 ConversionPatternRewriter &rewriter) const override {
2821 Location loc = op.getLoc();
2822 Type i32 = rewriter.getI32Type();
2823
2824 Value state = adaptor.getState();
2825
2826 Value initCountI64 = LLVM::LShrOp::create(
2827 rewriter, loc, state,
2828 createI64Constant(rewriter, loc, kDsBarrierInitCountPos));
2829 Value initCount = LLVM::TruncOp::create(rewriter, loc, i32, initCountI64);
2830
2831 rewriter.replaceOp(op, initCount);
2832 return success();
2833 }
2834};
2835
2836struct DsBarrierStatePhaseParityLowering
2837 : public ConvertOpToLLVMPattern<DsBarrierStatePhaseParity> {
2839
2840 LogicalResult
2841 matchAndRewrite(DsBarrierStatePhaseParity op, OpAdaptor adaptor,
2842 ConversionPatternRewriter &rewriter) const override {
2843 Location loc = op.getLoc();
2844 Type i1 = rewriter.getI1Type();
2845
2846 Value state = adaptor.getState();
2847
2848 Value noInitCount =
2849 LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), state);
2850 Value phase = LLVM::LShrOp::create(
2851 rewriter, loc, noInitCount,
2852 createI32Constant(rewriter, loc, kDsBarrierPhasePos));
2853 Value parity = LLVM::TruncOp::create(rewriter, loc, i1, phase);
2854
2855 rewriter.replaceOp(op, parity);
2856 return success();
2857 }
2858};
2859
2860//===----------------------------------------------------------------------===//
2861// Tensor Data Mover (TDM)
2862//===----------------------------------------------------------------------===//
2863
2864static Value setValueAtOffset(ConversionPatternRewriter &rewriter, Location loc,
2865 Value accumulator, Value value, int64_t shift) {
2866 shift = shift % 32;
2867 Value shiftAmount;
2868 if (shift != 0) {
2869 shiftAmount = createI32Constant(rewriter, loc, shift % 32);
2870 value = LLVM::ShlOp::create(rewriter, loc, value, shiftAmount);
2871 }
2872
2873 if (matchPattern(accumulator, mlir::m_Zero()))
2874 return value;
2875
2876 constexpr bool isDisjoint = true;
2877 return LLVM::OrOp::create(rewriter, loc, accumulator, value, isDisjoint);
2878}
2879
2880template <typename BaseOp>
2881struct AMDGPUMakeDmaBaseLowering : public ConvertOpToLLVMPattern<BaseOp> {
2882 using ConvertOpToLLVMPattern<BaseOp>::ConvertOpToLLVMPattern;
2883 using Adaptor = typename ConvertOpToLLVMPattern<BaseOp>::OpAdaptor;
2884
2885 AMDGPUMakeDmaBaseLowering(const LLVMTypeConverter &converter, Chipset chipset)
2886 : ConvertOpToLLVMPattern<BaseOp>(converter), chipset(chipset) {}
2887 Chipset chipset;
2888
2889 LogicalResult
2890 matchAndRewrite(BaseOp op, Adaptor adaptor,
2891 ConversionPatternRewriter &rewriter) const override {
2892 if (chipset < kGfx1250)
2893 return op->emitOpError("make_dma_base is only supported on gfx1250");
2894
2895 Location loc = op.getLoc();
2896
2897 constexpr int32_t constlen = 4;
2898 Value consts[constlen];
2899 for (int64_t i = 0; i < constlen; ++i)
2900 consts[i] = createI32Constant(rewriter, loc, i);
2901
2902 constexpr int32_t sgprslen = constlen;
2903 Value sgprs[sgprslen];
2904 for (int64_t i = 0; i < sgprslen; ++i) {
2905 sgprs[i] = consts[0];
2906 }
2907
2908 sgprs[0] = consts[1];
2909
2910 if constexpr (BaseOp::isGather()) {
2911 sgprs[0] = setValueAtOffset(rewriter, loc, sgprs[0], consts[1], 30);
2912
2913 auto type = cast<TDMGatherBaseType>(op.getResult().getType());
2914 Type indexType = type.getIndexType();
2915 unsigned indexSize = indexType.getIntOrFloatBitWidth();
2916 assert(llvm::is_contained({16u, 32u}, indexSize) &&
2917 "expected index_size to be 16 or 32");
2918 unsigned idx = (indexSize / 16) - 1;
2919
2920 if (idx)
2921 sgprs[0] = setValueAtOffset(rewriter, loc, sgprs[0], consts[1], 31);
2922 }
2923
2924 ValueRange ldsIndices = adaptor.getLdsIndices();
2925 Value lds = adaptor.getLds();
2926 auto ldsMemRefType = cast<MemRefType>(op.getLds().getType());
2927
2929 rewriter, loc, ldsMemRefType, lds, ldsIndices);
2930
2931 ValueRange globalIndices = adaptor.getGlobalIndices();
2932 Value global = adaptor.getGlobal();
2933 auto globalMemRefType = cast<MemRefType>(op.getGlobal().getType());
2934
2936 rewriter, loc, globalMemRefType, global, globalIndices);
2937
2938 Type i32 = rewriter.getI32Type();
2939 Type i64 = rewriter.getI64Type();
2940
2941 sgprs[1] = LLVM::PtrToIntOp::create(rewriter, loc, i32, ldsPtr);
2942 Value castForGlobalAddr =
2943 LLVM::PtrToIntOp::create(rewriter, loc, i64, globalPtr);
2944
2945 sgprs[2] = LLVM::TruncOp::create(rewriter, loc, i32, castForGlobalAddr);
2946
2947 Value shift = LLVM::LShrOp::create(rewriter, loc, castForGlobalAddr,
2948 createI64Constant(rewriter, loc, 32));
2949
2950 Value highHalf = LLVM::TruncOp::create(rewriter, loc, i32, shift);
2951
2952 Value mask = createI32Constant(rewriter, loc, (1ull << 25) - 1);
2953 highHalf = LLVM::AndOp::create(rewriter, loc, highHalf, mask);
2954
2955 sgprs[3] = setValueAtOffset(rewriter, loc, highHalf, consts[2], 30);
2956
2957 Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
2958 assert(v4i32 && "expected type conversion to succeed");
2959 Value result = LLVM::PoisonOp::create(rewriter, loc, v4i32);
2960
2961 for (auto [sgpr, constant] : llvm::zip_equal(sgprs, consts))
2962 result =
2963 LLVM::InsertElementOp::create(rewriter, loc, result, sgpr, constant);
2964
2965 rewriter.replaceOp(op, result);
2966 return success();
2967 }
2968};
2969
2970template <typename DescriptorOp>
2971struct AMDGPULowerDescriptor : public ConvertOpToLLVMPattern<DescriptorOp> {
2972 using ConvertOpToLLVMPattern<DescriptorOp>::ConvertOpToLLVMPattern;
2973 using OpAdaptor = typename ConvertOpToLLVMPattern<DescriptorOp>::OpAdaptor;
2974
2975 AMDGPULowerDescriptor(const LLVMTypeConverter &converter, Chipset chipset)
2976 : ConvertOpToLLVMPattern<DescriptorOp>(converter), chipset(chipset) {}
2977 Chipset chipset;
2978
2979 Value getDGroup0(OpAdaptor adaptor) const { return adaptor.getBase(); }
2980
2981 Value setWorkgroupMask(DescriptorOp op, OpAdaptor adaptor,
2982 ConversionPatternRewriter &rewriter, Location loc,
2983 Value sgpr0) const {
2984 Value mask = op.getWorkgroupMask();
2985 if (!mask)
2986 return sgpr0;
2987
2988 Type i16 = rewriter.getI16Type();
2989 mask = LLVM::BitcastOp::create(rewriter, loc, i16, mask);
2990 Type i32 = rewriter.getI32Type();
2991 Value extendedMask = LLVM::ZExtOp::create(rewriter, loc, i32, mask);
2992 return setValueAtOffset(rewriter, loc, sgpr0, extendedMask, 0);
2993 }
2994
2995 Value setDataSize(DescriptorOp op, OpAdaptor adaptor,
2996 ConversionPatternRewriter &rewriter, Location loc,
2997 Value sgpr0, ArrayRef<Value> consts) const {
2998 unsigned elementTypeWidthInBits = op.getElementTypeWidth();
2999 assert(llvm::is_contained({8u, 16u, 32u, 64u}, elementTypeWidthInBits) &&
3000 "expected type width to be 8, 16, 32, or 64.");
3001 int64_t idx = llvm::Log2_32(elementTypeWidthInBits / 8);
3002 Value size = consts[idx];
3003 return setValueAtOffset(rewriter, loc, sgpr0, size, 16);
3004 }
3005
3006 Value setAtomicBarrier(DescriptorOp op, OpAdaptor adaptor,
3007 ConversionPatternRewriter &rewriter, Location loc,
3008 Value sgpr0, ArrayRef<Value> consts) const {
3009 if (!adaptor.getAtomicBarrierAddress())
3010 return sgpr0;
3011
3012 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 18);
3013 }
3014
3015 Value setIterateEnable(DescriptorOp op, OpAdaptor adaptor,
3016 ConversionPatternRewriter &rewriter, Location loc,
3017 Value sgpr0, ArrayRef<Value> consts) const {
3018 if (!adaptor.getGlobalIncrement())
3019 return sgpr0;
3020
3021 // Value is ignored when in gather mode.
3022 // TODO: emit error earlier?
3023 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 19);
3024 }
3025
3026 Value setPadEnable(DescriptorOp op, OpAdaptor adaptor,
3027 ConversionPatternRewriter &rewriter, Location loc,
3028 Value sgpr0, ArrayRef<Value> consts) const {
3029 if (!op.getPadAmount())
3030 return sgpr0;
3031
3032 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 20);
3033 }
3034
3035 Value setEarlyTimeout(DescriptorOp op, OpAdaptor adaptor,
3036 ConversionPatternRewriter &rewriter, Location loc,
3037 Value sgpr0, ArrayRef<Value> consts) const {
3038 if (!op.getWorkgroupMask())
3039 return sgpr0;
3040
3041 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 21);
3042 }
3043
3044 Value setPadInterval(DescriptorOp op, OpAdaptor adaptor,
3045 ConversionPatternRewriter &rewriter, Location loc,
3046 Value sgpr0, ArrayRef<Value> consts) const {
3047 if (!op.getPadAmount())
3048 return sgpr0;
3049
3050 // pre-condition: padInterval can be a power of two between 2 and 256.
3051 // TODO: Validation if the value breaks the pre-condition.
3052 // If the pre-condition fails, there is a possibility of
3053 // affecting the higher bits. In a following PR implement
3054 // RuntimeVerifiableOpInterface that instruments conditions that need to be
3055 // checked at runtime.
3056 IntegerType i32 = rewriter.getI32Type();
3057 Value padInterval = adaptor.getPadInterval();
3058 padInterval = LLVM::CountTrailingZerosOp::create(rewriter, loc, i32,
3059 padInterval, false);
3060 padInterval = LLVM::SubOp::create(rewriter, loc, padInterval, consts[1]);
3061 // post-condition: padInterval can be a value between 0 and 7.
3062 return setValueAtOffset(rewriter, loc, sgpr0, padInterval, 22);
3063 }
3064
3065 Value setPadAmount(DescriptorOp op, OpAdaptor adaptor,
3066 ConversionPatternRewriter &rewriter, Location loc,
3067 Value sgpr0, ArrayRef<Value> consts) const {
3068 if (!op.getPadAmount())
3069 return sgpr0;
3070
3071 // pre-condition: padAmount is a value between 1-128.
3072 // TODO: Validation if the value breaks the pre-condition.
3073 // If the pre-condition fails, there is a possibility of
3074 // affecting the higher bits. In a following PR implement
3075 // RuntimeVerifiableOpInterface that instruments conditions that need to be
3076 // checked at runtime.
3077 Value padAmount = adaptor.getPadAmount();
3078 padAmount = LLVM::SubOp::create(rewriter, loc, padAmount, consts[1]);
3079 // post-condition: padAmount is a value between 0-127.
3080 return setValueAtOffset(rewriter, loc, sgpr0, padAmount, 25);
3081 }
3082
3083 Value setAtomicBarrierAddress(DescriptorOp op, OpAdaptor adaptor,
3084 ConversionPatternRewriter &rewriter,
3085 Location loc, Value sgpr1,
3086 ArrayRef<Value> consts) const {
3087 if (!adaptor.getAtomicBarrierAddress())
3088 return sgpr1;
3089
3090 Value atomicBarrierAddress = adaptor.getAtomicBarrierAddress();
3091 auto barrierAddressTy =
3092 cast<MemRefType>(op.getAtomicBarrierAddress().getType());
3093 ValueRange atomicBarrierIndices = adaptor.getAtomicBarrierIndices();
3094 atomicBarrierAddress = ConvertToLLVMPattern::getStridedElementPtr(
3095 rewriter, loc, barrierAddressTy, atomicBarrierAddress,
3096 atomicBarrierIndices);
3097 IntegerType i32 = rewriter.getI32Type();
3098 // pre-condition: atomicBarrierAddress is aligned to 8 bytes which implies
3099 // that the 3 LSBs are zero.
3100 // TODO: Validation if the value breaks the pre-condition.
3101 // In a following PR implement RuntimeVerifiableOpInterface
3102 // that instruments conditions that need to be checked at runtime.
3103 atomicBarrierAddress =
3104 LLVM::PtrToIntOp::create(rewriter, loc, i32, atomicBarrierAddress);
3105 atomicBarrierAddress =
3106 LLVM::LShrOp::create(rewriter, loc, atomicBarrierAddress, consts[3]);
3107 Value mask = createI32Constant(rewriter, loc, 0xFFFF);
3108 atomicBarrierAddress =
3109 LLVM::AndOp::create(rewriter, loc, atomicBarrierAddress, mask);
3110 return setValueAtOffset(rewriter, loc, sgpr1, atomicBarrierAddress, 32);
3111 }
3112
3113 std::pair<Value, Value> setTensorDimX(DescriptorOp op, OpAdaptor adaptor,
3114 ConversionPatternRewriter &rewriter,
3115 Location loc, Value sgpr1, Value sgpr2,
3116 ArrayRef<Value> consts, uint64_t dimX,
3117 uint32_t offset) const {
3118 ArrayRef<int64_t> globalStaticSizes = adaptor.getGlobalStaticSizes();
3119 ValueRange globalDynamicSizes = adaptor.getGlobalDynamicSizes();
3120 SmallVector<OpFoldResult> mixedGlobalSizes =
3121 getMixedValues(globalStaticSizes, globalDynamicSizes, rewriter);
3122 if (mixedGlobalSizes.size() <= dimX)
3123 return {sgpr1, sgpr2};
3124
3125 OpFoldResult tensorDimXOpFoldResult = *(mixedGlobalSizes.rbegin() + dimX);
3126 // pre-condition: tensorDimX is less than 2^32-1
3127 // TODO: Validation if the value breaks the pre-condition.
3128 // In a following PR implement RuntimeVerifiableOpInterface that instruments
3129 // conditions that need to be checked at runtime. This could also be fixed
3130 // by saying that mixedGlobalSizes is a DynamicI32List.
3131 Value tensorDimX;
3132 if (auto attr = dyn_cast<Attribute>(tensorDimXOpFoldResult)) {
3133 tensorDimX =
3134 createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
3135 } else {
3136 IntegerType i32 = rewriter.getI32Type();
3137 tensorDimX = cast<Value>(tensorDimXOpFoldResult);
3138 tensorDimX = LLVM::TruncOp::create(rewriter, loc, i32, tensorDimX);
3139 }
3140
3141 sgpr1 = setValueAtOffset(rewriter, loc, sgpr1, tensorDimX, offset);
3142
3143 Value c16 = createI32Constant(rewriter, loc, 16);
3144 Value tensorDimXHigh = LLVM::LShrOp::create(rewriter, loc, tensorDimX, c16);
3145 sgpr2 = setValueAtOffset(rewriter, loc, sgpr2, tensorDimXHigh, offset + 16);
3146 return {sgpr1, sgpr2};
3147 }
3148
3149 std::pair<Value, Value> setTensorDim0(DescriptorOp op, OpAdaptor adaptor,
3150 ConversionPatternRewriter &rewriter,
3151 Location loc, Value sgpr1, Value sgpr2,
3152 ArrayRef<Value> consts) const {
3153 return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, sgpr2, consts, 0,
3154 48);
3155 }
3156
3157 std::pair<Value, Value> setTensorDim1(DescriptorOp op, OpAdaptor adaptor,
3158 ConversionPatternRewriter &rewriter,
3159 Location loc, Value sgpr2, Value sgpr3,
3160 ArrayRef<Value> consts) const {
3161 return setTensorDimX(op, adaptor, rewriter, loc, sgpr2, sgpr3, consts, 1,
3162 80);
3163 }
3164
3165 Value setTileDimX(DescriptorOp op, OpAdaptor adaptor,
3166 ConversionPatternRewriter &rewriter, Location loc,
3167 Value sgpr, ArrayRef<Value> consts, size_t dimX,
3168 int64_t offset) const {
3169 ArrayRef<int64_t> sharedStaticSizes = adaptor.getSharedStaticSizes();
3170 ValueRange sharedDynamicSizes = adaptor.getSharedDynamicSizes();
3171 SmallVector<OpFoldResult> mixedSharedSizes =
3172 getMixedValues(sharedStaticSizes, sharedDynamicSizes, rewriter);
3173 if (mixedSharedSizes.size() <= dimX)
3174 return sgpr;
3175
3176 OpFoldResult tileDimXOpFoldResult = *(mixedSharedSizes.rbegin() + dimX);
3177 // pre-condition: tileDimX is less than 2^16-1
3178 // TODO: Validation if the value breaks the pre-condition.
3179 // If the pre-condition fails, there is a possibility of
3180 // affecting the higher bits. In a following PR implement
3181 // RuntimeVerifiableOpInterface that instruments conditions that need to be
3182 // checked at runtime. This could also be fixed by saying that
3183 // mixedSharedSizes is a DynamicI16List.
3184 Value tileDimX;
3185 if (auto attr = dyn_cast<Attribute>(tileDimXOpFoldResult)) {
3186 tileDimX =
3187 createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
3188 } else {
3189 IntegerType i32 = rewriter.getI32Type();
3190 tileDimX = cast<Value>(tileDimXOpFoldResult);
3191 tileDimX = LLVM::TruncOp::create(rewriter, loc, i32, tileDimX);
3192 }
3193
3194 return setValueAtOffset(rewriter, loc, sgpr, tileDimX, offset);
3195 }
3196
3197 Value setTileDim0(DescriptorOp op, OpAdaptor adaptor,
3198 ConversionPatternRewriter &rewriter, Location loc,
3199 Value sgpr3, ArrayRef<Value> consts) const {
3200 return setTileDimX(op, adaptor, rewriter, loc, sgpr3, consts, 0, 112);
3201 }
3202
3203 Value setTileDim1(DescriptorOp op, OpAdaptor adaptor,
3204 ConversionPatternRewriter &rewriter, Location loc,
3205 Value sgpr4, ArrayRef<Value> consts) const {
3206 return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 1, 128);
3207 }
3208
3209 Value setValidIndices(DescriptorOp op, OpAdaptor adaptor,
3210 ConversionPatternRewriter &rewriter, Location loc,
3211 Value sgpr4, ArrayRef<Value> consts) const {
3212 auto type = cast<VectorType>(op.getIndices().getType());
3213 ArrayRef<int64_t> shape = type.getShape();
3214 assert(shape.size() == 1 && "expected shape to be of rank 1.");
3215 unsigned length = shape.back();
3216 assert(0 < length && length <= 16 && "expected length to be at most 16.");
3217 Value value = createI32Constant(rewriter, loc, length);
3218 return setValueAtOffset(rewriter, loc, sgpr4, value, 128);
3219 }
3220
3221 Value setTileDim1OrValidIndices(DescriptorOp op, OpAdaptor adaptor,
3222 ConversionPatternRewriter &rewriter,
3223 Location loc, Value sgpr4,
3224 ArrayRef<Value> consts) const {
3225 if constexpr (DescriptorOp::isGather())
3226 return setValidIndices(op, adaptor, rewriter, loc, sgpr4, consts);
3227 return setTileDim1(op, adaptor, rewriter, loc, sgpr4, consts);
3228 }
3229
3230 Value setTileDim2(DescriptorOp op, OpAdaptor adaptor,
3231 ConversionPatternRewriter &rewriter, Location loc,
3232 Value sgpr4, ArrayRef<Value> consts) const {
3233 // Value is ignored when in gather mode.
3234 if constexpr (DescriptorOp::isGather())
3235 return sgpr4;
3236 return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 2, 144);
3237 }
3238
  /// Packs the 48-bit stride of global dimension `dimX` into two descriptor
  /// words: the low 32 bits into `sgprY` at `offset`, the remaining high bits
  /// into `sgprZ` at the following word boundary. A no-op when the op has no
  /// stride for that dimension. Returns the updated {sgprY, sgprZ}.
  std::pair<Value, Value>
  setTensorDimXStride(DescriptorOp op, OpAdaptor adaptor,
                      ConversionPatternRewriter &rewriter, Location loc,
                      Value sgprY, Value sgprZ, ArrayRef<Value> consts,
                      size_t dimX, int64_t offset) const {
    ArrayRef<int64_t> globalStaticStrides = adaptor.getGlobalStaticStrides();
    ValueRange globalDynamicStrides = adaptor.getGlobalDynamicStrides();
    SmallVector<OpFoldResult> mixedGlobalStrides =
        getMixedValues(globalStaticStrides, globalDynamicStrides, rewriter);

    // `+ 1` skips the innermost stride: strides are stored outermost-first
    // and the stride of dimension dimX lives one slot past it from the back.
    if (mixedGlobalStrides.size() <= (dimX + 1))
      return {sgprY, sgprZ};

    OpFoldResult tensorDimXStrideOpFoldResult =
        *(mixedGlobalStrides.rbegin() + dimX + 1);
    // pre-condition: tensorDimXStride is less than 2^48-1
    // TODO: Validation if the value breaks the pre-condition.
    // In a following PR implement RuntimeVerifiableOpInterface that instruments
    // conditions that need to be checked at runtime.
    Value tensorDimXStride;
    if (auto attr = dyn_cast<Attribute>(tensorDimXStrideOpFoldResult))
      tensorDimXStride =
          createI64Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
    else
      tensorDimXStride = cast<Value>(tensorDimXStrideOpFoldResult);

    // Keep only the 48 bits the descriptor field can hold.
    constexpr int64_t first48bits = (1ll << 48) - 1;
    Value mask = createI64Constant(rewriter, loc, first48bits);
    tensorDimXStride =
        LLVM::AndOp::create(rewriter, loc, mask, tensorDimXStride);
    IntegerType i32 = rewriter.getI32Type();
    Value tensorDimXStrideLow =
        LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStride);
    sgprY = setValueAtOffset(rewriter, loc, sgprY, tensorDimXStrideLow, offset);

    // Shift so the high part starts exactly at the next 32-bit word: when
    // `offset` is word-aligned the split is at bit 32, otherwise at the
    // distance to the word boundary (offset % 32).
    int64_t shift = (offset % 32) == 0 ? 32 : offset % 32;
    Value shiftVal = createI64Constant(rewriter, loc, shift);
    Value tensorDimXStrideHigh =
        LLVM::LShrOp::create(rewriter, loc, tensorDimXStride, shiftVal);
    tensorDimXStrideHigh =
        LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStrideHigh);
    sgprZ = setValueAtOffset(rewriter, loc, sgprZ, tensorDimXStrideHigh,
                             offset + shift);
    return {sgprY, sgprZ};
  }
3284
3285 std::pair<Value, Value>
3286 setTensorDim0Stride(DescriptorOp op, OpAdaptor adaptor,
3287 ConversionPatternRewriter &rewriter, Location loc,
3288 Value sgpr5, Value sgpr6, ArrayRef<Value> consts) const {
3289 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts,
3290 0, 160);
3291 }
3292
3293 std::pair<Value, Value>
3294 setTensorDim1Stride(DescriptorOp op, OpAdaptor adaptor,
3295 ConversionPatternRewriter &rewriter, Location loc,
3296 Value sgpr5, Value sgpr6, ArrayRef<Value> consts) const {
3297 // Value is ignored when in gather mode.
3298 if constexpr (DescriptorOp::isGather())
3299 return {sgpr5, sgpr6};
3300 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts,
3301 1, 208);
3302 }
3303
  /// Assembles the second descriptor group (dgroup1) as a v8i32 vector.
  /// Each setter ORs a field into one (or a pair of) 32-bit word(s); the
  /// call order matches the field layout, and fields that straddle a word
  /// boundary are threaded through two adjacent sgprs via std::tie.
  Value getDGroup1(DescriptorOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter, Location loc,
                   ArrayRef<Value> consts) const {
    // Start every word at zero (consts[0] is the i32 constant 0).
    Value sgprs[8];
    for (int64_t i = 0; i < 8; ++i) {
      sgprs[i] = consts[0];
    }

    // Word 0: control flags and small fields.
    sgprs[0] = setWorkgroupMask(op, adaptor, rewriter, loc, sgprs[0]);
    sgprs[0] = setDataSize(op, adaptor, rewriter, loc, sgprs[0], consts);
    sgprs[0] = setAtomicBarrier(op, adaptor, rewriter, loc, sgprs[0], consts);
    sgprs[0] = setIterateEnable(op, adaptor, rewriter, loc, sgprs[0], consts);
    sgprs[0] = setPadEnable(op, adaptor, rewriter, loc, sgprs[0], consts);
    sgprs[0] = setEarlyTimeout(op, adaptor, rewriter, loc, sgprs[0], consts);
    sgprs[0] = setPadInterval(op, adaptor, rewriter, loc, sgprs[0], consts);
    sgprs[0] = setPadAmount(op, adaptor, rewriter, loc, sgprs[0], consts);

    // Words 1-3: barrier address plus tensor dimensions 0 and 1 (each
    // dimension spans two words).
    sgprs[1] =
        setAtomicBarrierAddress(op, adaptor, rewriter, loc, sgprs[1], consts);
    std::tie(sgprs[1], sgprs[2]) =
        setTensorDim0(op, adaptor, rewriter, loc, sgprs[1], sgprs[2], consts);
    std::tie(sgprs[2], sgprs[3]) =
        setTensorDim1(op, adaptor, rewriter, loc, sgprs[2], sgprs[3], consts);

    // Words 3-7: tile dimensions (or gather index count) and strides.
    sgprs[3] = setTileDim0(op, adaptor, rewriter, loc, sgprs[3], consts);
    sgprs[4] =
        setTileDim1OrValidIndices(op, adaptor, rewriter, loc, sgprs[4], consts);
    sgprs[4] = setTileDim2(op, adaptor, rewriter, loc, sgprs[4], consts);
    std::tie(sgprs[5], sgprs[6]) = setTensorDim0Stride(
        op, adaptor, rewriter, loc, sgprs[5], sgprs[6], consts);
    std::tie(sgprs[6], sgprs[7]) = setTensorDim1Stride(
        op, adaptor, rewriter, loc, sgprs[6], sgprs[7], consts);

    IntegerType i32 = rewriter.getI32Type();
    Type v8i32 = this->typeConverter->convertType(VectorType::get(8, i32));
    assert(v8i32 && "expected type conversion to succeed");
    Value dgroup1 = LLVM::PoisonOp::create(rewriter, loc, v8i32);

    // consts holds i32 constants 0..7, reused here as insertion indices.
    for (auto [sgpr, constant] : llvm::zip_equal(sgprs, consts)) {
      dgroup1 =
          LLVM::InsertElementOp::create(rewriter, loc, dgroup1, sgpr, constant);
    }

    return dgroup1;
  }
3349
3350 Value setTensorDimX(DescriptorOp op, OpAdaptor adaptor,
3351 ConversionPatternRewriter &rewriter, Location loc,
3352 Value sgpr0, ArrayRef<Value> consts, int64_t dimX,
3353 int64_t offset) const {
3354 ArrayRef<int64_t> globalStaticSizes = adaptor.getGlobalStaticSizes();
3355 ValueRange globalDynamicSizes = adaptor.getGlobalDynamicSizes();
3356 SmallVector<OpFoldResult> mixedGlobalSizes =
3357 getMixedValues(globalStaticSizes, globalDynamicSizes, rewriter);
3358 if (mixedGlobalSizes.size() <= static_cast<unsigned long>(dimX))
3359 return sgpr0;
3360
3361 OpFoldResult tensorDimXOpFoldResult = *(mixedGlobalSizes.rbegin() + dimX);
3362 Value tensorDimX;
3363 if (auto attr = dyn_cast<Attribute>(tensorDimXOpFoldResult)) {
3364 tensorDimX =
3365 createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
3366 } else {
3367 IntegerType i32 = rewriter.getI32Type();
3368 tensorDimX = cast<Value>(tensorDimXOpFoldResult);
3369 tensorDimX = LLVM::TruncOp::create(rewriter, loc, i32, tensorDimX);
3370 }
3371
3372 return setValueAtOffset(rewriter, loc, sgpr0, tensorDimX, offset);
3373 }
3374
3375 Value setTensorDim2(DescriptorOp op, OpAdaptor adaptor,
3376 ConversionPatternRewriter &rewriter, Location loc,
3377 Value sgpr0, ArrayRef<Value> consts) const {
3378 return setTensorDimX(op, adaptor, rewriter, loc, sgpr0, consts, 2, 0);
3379 }
3380
3381 Value truncateAndSetValueAtOffset(ConversionPatternRewriter &rewriter,
3382 Location loc, Value accumulator,
3383 Value value, int64_t shift) const {
3384
3385 IntegerType i32 = rewriter.getI32Type();
3386 value = LLVM::TruncOp::create(rewriter, loc, i32, value);
3387 return setValueAtOffset(rewriter, loc, accumulator, value, shift);
3388 }
3389
3390 Value setLDSAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
3391 ConversionPatternRewriter &rewriter, Location loc,
3392 Value sgpr1, ArrayRef<Value> consts,
3393 int64_t offset) const {
3394 Value ldsAddrIncrement = adaptor.getLdsIncrement();
3395 return setValueAtOffset(rewriter, loc, sgpr1, ldsAddrIncrement, offset);
3396 }
3397
3398 std::pair<Value, Value>
3399 setGlobalAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
3400 ConversionPatternRewriter &rewriter, Location loc,
3401 Value sgpr2, Value sgpr3, ArrayRef<Value> consts,
3402 int64_t offset) const {
3403 Value globalAddrIncrement = adaptor.getGlobalIncrement();
3404 sgpr2 = truncateAndSetValueAtOffset(rewriter, loc, sgpr2,
3405 globalAddrIncrement, offset);
3406 Value shift = createI64Constant(rewriter, loc, 32);
3407 globalAddrIncrement =
3408 LLVM::LShrOp::create(rewriter, loc, globalAddrIncrement, shift);
3409 constexpr int64_t first16BitsHigh = (1ll << 16) - 1;
3410 sgpr3 = truncateAndSetValueAtOffset(rewriter, loc, sgpr3,
3411 globalAddrIncrement, offset + 32);
3412 Value mask = createI32Constant(rewriter, loc, first16BitsHigh);
3413 sgpr3 = LLVM::AndOp::create(rewriter, loc, sgpr3, mask);
3414 return {sgpr2, sgpr3};
3415 }
3416
3417 Value setTensorDim3OrLDSAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
3418 ConversionPatternRewriter &rewriter,
3419 Location loc, Value sgpr1,
3420 ArrayRef<Value> consts) const {
3421 Value ldsIncrement = op.getLdsIncrement();
3422 constexpr int64_t dim = 3;
3423 constexpr int64_t offset = 32;
3424 if (!ldsIncrement)
3425 return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, consts, dim,
3426 offset);
3427 return setLDSAddrIncrement(op, adaptor, rewriter, loc, sgpr1, consts,
3428 offset);
3429 }
3430
3431 std::pair<Value, Value> setTensorDim2StrideOrGlobalAddrIncrement(
3432 DescriptorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter,
3433 Location loc, Value sgpr2, Value sgpr3, ArrayRef<Value> consts) const {
3434 Value globalIncrement = op.getGlobalIncrement();
3435 constexpr int32_t dim = 2;
3436 constexpr int32_t offset = 64;
3437 if (!globalIncrement)
3438 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr2, sgpr3,
3439 consts, dim, offset);
3440 return setGlobalAddrIncrement(op, adaptor, rewriter, loc, sgpr2, sgpr3,
3441 consts, offset);
3442 }
3443
3444 Value setIterateCount(DescriptorOp op, OpAdaptor adaptor,
3445 ConversionPatternRewriter &rewriter, Location loc,
3446 Value sgpr3, ArrayRef<Value> consts,
3447 int32_t offset) const {
3448 Value iterationCount = adaptor.getIterationCount();
3449 IntegerType i32 = rewriter.getI32Type();
3450 // pre-condition: iterationCount is in the inclusive interval [1, 256].
3451 // TODO: validation if the value breaks the pre-condition.
3452 // If the pre-condition fails, there is a possibility of
3453 // affecting the higher bits. In a following PR implement
3454 // RuntimeVerifiableOpInterface that instruments conditions that need to be
3455 // checked at runtime.
3456 iterationCount = LLVM::TruncOp::create(rewriter, loc, i32, iterationCount);
3457 iterationCount =
3458 LLVM::SubOp::create(rewriter, loc, iterationCount, consts[1]);
3459 return setValueAtOffset(rewriter, loc, sgpr3, iterationCount, offset);
3460 }
3461
3462 Value setTileDim3OrIterateCount(DescriptorOp op, OpAdaptor adaptor,
3463 ConversionPatternRewriter &rewriter,
3464 Location loc, Value sgpr3,
3465 ArrayRef<Value> consts) const {
3466 Value iterateCount = op.getIterationCount();
3467 constexpr int32_t dim = 2;
3468 constexpr int32_t offset = 112;
3469 if (!iterateCount)
3470 return setTileDimX(op, adaptor, rewriter, loc, sgpr3, consts, dim,
3471 offset);
3472
3473 return setIterateCount(op, adaptor, rewriter, loc, sgpr3, consts, offset);
3474 }
3475
3476 Value getDGroup2(DescriptorOp op, OpAdaptor adaptor,
3477 ConversionPatternRewriter &rewriter, Location loc,
3478 ArrayRef<Value> consts) const {
3479 if constexpr (DescriptorOp::isGather())
3480 return getDGroup2Gather(op, adaptor, rewriter, loc, consts);
3481 return getDGroup2NonGather(op, adaptor, rewriter, loc, consts);
3482 }
3483
  /// Assembles dgroup2 (v4i32) for non-gather descriptors: tensor dim 2,
  /// dim 3 / LDS increment, dim-2 stride / global increment, and tile dim /
  /// iterate count. Emits all-zero words when only two descriptor groups are
  /// needed (rank <= 2 and no LDS increment).
  Value getDGroup2NonGather(DescriptorOp op, OpAdaptor adaptor,
                            ConversionPatternRewriter &rewriter, Location loc,
                            ArrayRef<Value> consts) const {
    IntegerType i32 = rewriter.getI32Type();
    Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
    assert(v4i32 && "expected type conversion to succeed.");

    bool onlyNeedsTwoDescriptors = !op.getLdsIncrement() && op.getRank() <= 2;
    if (onlyNeedsTwoDescriptors)
      return LLVM::ZeroOp::create(rewriter, loc, v4i32);

    // Start every word at zero (consts[0] is the i32 constant 0).
    constexpr int64_t sgprlen = 4;
    Value sgprs[sgprlen];
    for (int i = 0; i < sgprlen; ++i)
      sgprs[i] = consts[0];

    sgprs[0] = setTensorDim2(op, adaptor, rewriter, loc, sgprs[0], consts);
    sgprs[1] = setTensorDim3OrLDSAddrIncrement(op, adaptor, rewriter, loc,
                                               sgprs[1], consts);
    std::tie(sgprs[2], sgprs[3]) = setTensorDim2StrideOrGlobalAddrIncrement(
        op, adaptor, rewriter, loc, sgprs[2], sgprs[3], consts);
    sgprs[3] =
        setTileDim3OrIterateCount(op, adaptor, rewriter, loc, sgprs[3], consts);

    // consts (0..7) doubles as insertion indices; zip stops at the shorter
    // range (the 4 sgprs).
    Value dgroup2 = LLVM::PoisonOp::create(rewriter, loc, v4i32);
    for (auto [sgpr, constant] : llvm::zip(sgprs, consts))
      dgroup2 =
          LLVM::InsertElementOp::create(rewriter, loc, dgroup2, sgpr, constant);

    return dgroup2;
  }
3515
  /// Packs one half of the gather `indices` vector into a v4i32 group.
  /// i32 indices occupy one word each (up to 4 per group); narrower indices
  /// are zero-extended to i32 and packed two per word (up to 8 per group),
  /// low index in the low 16 bits. `firstHalf` selects which half of the
  /// index vector this group covers.
  Value getGatherIndices(DescriptorOp op, OpAdaptor adaptor,
                         ConversionPatternRewriter &rewriter, Location loc,
                         ArrayRef<Value> consts, bool firstHalf) const {
    IntegerType i32 = rewriter.getI32Type();
    Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
    assert(v4i32 && "expected type conversion to succeed.");

    Value indices = adaptor.getIndices();
    auto vectorType = cast<VectorType>(indices.getType());
    unsigned length = vectorType.getShape().back();
    Type elementType = vectorType.getElementType();
    // Per-group capacity: 4 full words for i32, 8 packed halves otherwise.
    unsigned maxLength = elementType == i32 ? 4 : 8;
    int32_t offset = firstHalf ? 0 : maxLength;
    // Remaining indices after skipping the first half; the signed cast
    // clamps the second half to 0 when the vector is shorter than `offset`.
    unsigned discountedLength =
        std::max(static_cast<int32_t>(length - offset), 0);

    unsigned targetSize = std::min(maxLength, discountedLength);

    // Extract the covered elements; consts provides the constants 0..7,
    // larger positions get fresh i32 constants.
    SmallVector<Value> indicesVector;
    for (unsigned i = offset; i < targetSize + offset; ++i) {
      Value idx;
      if (i < consts.size())
        idx = consts[i];
      else
        idx = createI32Constant(rewriter, loc, i);
      Value elem = LLVM::ExtractElementOp::create(rewriter, loc, indices, idx);
      indicesVector.push_back(elem);
    }

    // Widen sub-i32 indices to i32; i32 indices pass through unchanged.
    SmallVector<Value> indicesI32Vector;
    if (elementType == i32) {
      indicesI32Vector = indicesVector;
    } else {
      for (unsigned i = 0; i < targetSize; ++i) {
        Value index = indicesVector[i];
        indicesI32Vector.push_back(
            LLVM::ZExtOp::create(rewriter, loc, i32, index));
      }
      if ((targetSize % 2) != 0)
        // Add padding when not divisible by two.
        indicesI32Vector.push_back(consts[0]);
    }

    // For narrow indices, fuse each pair into one word: the second index
    // goes into the high 16 bits of the first.
    SmallVector<Value> indicesToInsert;
    if (elementType == i32) {
      indicesToInsert = indicesI32Vector;
    } else {
      unsigned size = indicesI32Vector.size() / 2;
      for (unsigned i = 0; i < size; ++i) {
        Value first = indicesI32Vector[2 * i];
        Value second = indicesI32Vector[2 * i + 1];
        Value joined = setValueAtOffset(rewriter, loc, first, second, 16);
        indicesToInsert.push_back(joined);
      }
    }

    // Insert the packed words; zip_first stops after indicesToInsert, so
    // unused lanes keep their poison value.
    Value dgroup = LLVM::PoisonOp::create(rewriter, loc, v4i32);
    for (auto [sgpr, constant] : llvm::zip_first(indicesToInsert, consts))
      dgroup =
          LLVM::InsertElementOp::create(rewriter, loc, dgroup, sgpr, constant);

    return dgroup;
  }
3579
3580 Value getDGroup2Gather(DescriptorOp op, OpAdaptor adaptor,
3581 ConversionPatternRewriter &rewriter, Location loc,
3582 ArrayRef<Value> consts) const {
3583 return getGatherIndices(op, adaptor, rewriter, loc, consts, true);
3584 }
3585
3586 std::pair<Value, Value>
3587 setTensorDim3Stride(DescriptorOp op, OpAdaptor adaptor,
3588 ConversionPatternRewriter &rewriter, Location loc,
3589 Value sgpr0, Value sgpr1, ArrayRef<Value> consts) const {
3590 constexpr int32_t dim = 3;
3591 constexpr int32_t offset = 0;
3592 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr0, sgpr1, consts,
3593 dim, offset);
3594 }
3595
3596 std::pair<Value, Value> setTensorDim4(DescriptorOp op, OpAdaptor adaptor,
3597 ConversionPatternRewriter &rewriter,
3598 Location loc, Value sgpr1, Value sgpr2,
3599 ArrayRef<Value> consts) const {
3600 constexpr int32_t dim = 4;
3601 constexpr int32_t offset = 48;
3602 return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, sgpr2, consts, dim,
3603 offset);
3604 }
3605
3606 Value setTileDim4(DescriptorOp op, OpAdaptor adaptor,
3607 ConversionPatternRewriter &rewriter, Location loc,
3608 Value sgpr2, ArrayRef<Value> consts) const {
3609 constexpr int32_t dim = 4;
3610 constexpr int32_t offset = 80;
3611 return setTileDimX(op, adaptor, rewriter, loc, sgpr2, consts, dim, offset);
3612 }
3613
3614 Value getDGroup3(DescriptorOp op, OpAdaptor adaptor,
3615 ConversionPatternRewriter &rewriter, Location loc,
3616 ArrayRef<Value> consts) const {
3617 if constexpr (DescriptorOp::isGather())
3618 return getDGroup3Gather(op, adaptor, rewriter, loc, consts);
3619 return getDGroup3NonGather(op, adaptor, rewriter, loc, consts);
3620 }
3621
  /// Assembles dgroup3 (v4i32) for non-gather descriptors: dim-3 stride,
  /// tensor dim 4, and tile dim 4. Emits all-zero words when only two
  /// descriptor groups are needed (rank <= 2 and no LDS increment).
  Value getDGroup3NonGather(DescriptorOp op, OpAdaptor adaptor,
                            ConversionPatternRewriter &rewriter, Location loc,
                            ArrayRef<Value> consts) const {
    IntegerType i32 = rewriter.getI32Type();
    Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
    assert(v4i32 && "expected type conversion to succeed.");
    bool onlyNeedsTwoDescriptors = !op.getLdsIncrement() && op.getRank() <= 2;
    if (onlyNeedsTwoDescriptors)
      return LLVM::ZeroOp::create(rewriter, loc, v4i32);

    // Start every word at zero (consts[0] is the i32 constant 0).
    constexpr int32_t sgprlen = 4;
    Value sgprs[sgprlen];
    for (int i = 0; i < sgprlen; ++i)
      sgprs[i] = consts[0];

    std::tie(sgprs[0], sgprs[1]) = setTensorDim3Stride(
        op, adaptor, rewriter, loc, sgprs[0], sgprs[1], consts);
    std::tie(sgprs[1], sgprs[2]) =
        setTensorDim4(op, adaptor, rewriter, loc, sgprs[1], sgprs[2], consts);
    sgprs[2] = setTileDim4(op, adaptor, rewriter, loc, sgprs[2], consts);

    // consts (0..7) doubles as insertion indices; zip stops at the 4 sgprs.
    Value dgroup3 = LLVM::PoisonOp::create(rewriter, loc, v4i32);
    for (auto [sgpr, constant] : llvm::zip(sgprs, consts))
      dgroup3 =
          LLVM::InsertElementOp::create(rewriter, loc, dgroup3, sgpr, constant);

    return dgroup3;
  }
3650
3651 Value getDGroup3Gather(DescriptorOp op, OpAdaptor adaptor,
3652 ConversionPatternRewriter &rewriter, Location loc,
3653 ArrayRef<Value> consts) const {
3654 return getGatherIndices(op, adaptor, rewriter, loc, consts, false);
3655 }
3656
  /// Lowers the descriptor op to four i32 vector groups (v4i32, v8i32,
  /// v4i32, v4i32) and replaces the op with them via 1:N conversion.
  LogicalResult
  matchAndRewrite(DescriptorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // TDM descriptors only exist on gfx1250.
    if (chipset < kGfx1250)
      return op->emitOpError(
          "make_dma_descriptor is only supported on gfx1250");

    Location loc = op.getLoc();

    // Materialize the i32 constants 0..7 once; the helpers reuse them both
    // as arithmetic operands and as vector-insertion indices.
    SmallVector<Value> consts;
    for (int64_t i = 0; i < 8; ++i)
      consts.push_back(createI32Constant(rewriter, loc, i));

    Value dgroup0 = this->getDGroup0(adaptor);
    Value dgroup1 = this->getDGroup1(op, adaptor, rewriter, loc, consts);
    Value dgroup2 = this->getDGroup2(op, adaptor, rewriter, loc, consts);
    Value dgroup3 = this->getDGroup3(op, adaptor, rewriter, loc, consts);
    SmallVector<Value> results = {dgroup0, dgroup1, dgroup2, dgroup3};
    // One descriptor value maps to four converted values (1:N replacement).
    rewriter.replaceOpWithMultiple(op, {results});
    return success();
  }
3678};
3679
/// Lowers amdgpu tensor load/store ops to their ROCDL equivalents by
/// forwarding the four descriptor groups. Only legal on gfx1250+.
/// NOTE(review): the extraction this listing came from appears to have
/// dropped one line here (likely a `using Adaptor = ...;` alias that
/// `matchAndRewrite` relies on) — verify against upstream before compiling.
template <typename SourceOp, typename TargetOp>
struct AMDGPUTensorLoadStoreOpLowering
    : public ConvertOpToLLVMPattern<SourceOp> {
  using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
  AMDGPUTensorLoadStoreOpLowering(const LLVMTypeConverter &converter,
                                  Chipset chipset)
      : ConvertOpToLLVMPattern<SourceOp>(converter), chipset(chipset) {}
  // Target chipset the pass was configured with; gates the lowering.
  Chipset chipset;

  LogicalResult
  matchAndRewrite(SourceOp op, Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset < kGfx1250)
      return op->emitOpError("is only supported on gfx1250");

    // The converted descriptor is four values (1:N conversion); pass them
    // straight through. Cache policy and alias metadata are left unset.
    ValueRange desc = adaptor.getDesc();
    rewriter.replaceOpWithNewOp<TargetOp>(op, desc[0], desc[1], desc[2],
                                          desc[3], /*cachePolicy=*/0,
                                          /*alias_scopes=*/nullptr,
                                          /*noalias_scopes=*/nullptr,
                                          /*tbaa=*/nullptr);
    return success();
  }
};
3705
3706struct ConvertAMDGPUToROCDLPass
3707 : public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
3708 using Base::Base;
3709
3710 void runOnOperation() override {
3711 MLIRContext *ctx = &getContext();
3712 FailureOr<Chipset> maybeChipset = Chipset::parse(chipset);
3713 if (failed(maybeChipset)) {
3714 emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
3715 return signalPassFailure();
3716 }
3717
3718 RewritePatternSet patterns(ctx);
3719 LLVMTypeConverter converter(ctx);
3720
3721 populateAMDGPUToROCDLConversionPatterns(converter, patterns, *maybeChipset);
3722 amdgpu::populateCommonGPUTypeAndAttributeConversions(converter);
3723 LLVMConversionTarget target(getContext());
3724 target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
3725 target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
3726 target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
3727 if (failed(applyPartialConversion(getOperation(), target,
3728 std::move(patterns))))
3729 signalPassFailure();
3730 }
3731};
3732} // namespace
3733
// Maps the GPU dialect's symbolic address spaces onto the ROCDL numeric
// address-space constants.
// NOTE(review): the extraction this listing came from dropped this
// function's declarator line and the call line preceding the lambda —
// restore both from upstream before compiling.
    TypeConverter &typeConverter) {
      typeConverter, [](gpu::AddressSpace space) {
        switch (space) {
        case gpu::AddressSpace::Global:
          return ROCDL::ROCDLDialect::kGlobalMemoryAddressSpace;
        case gpu::AddressSpace::Workgroup:
          return ROCDL::ROCDLDialect::kSharedMemoryAddressSpace;
        case gpu::AddressSpace::Private:
          return ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace;
        }
        // All enumerators handled above; reaching here is a bug.
        llvm_unreachable("unknown address space enum value");
      });
}
3749
// Registers the AMDGPU-specific type and attribute conversions: buffer
// address-space attributes become LLVM address-space numbers, barrier state
// becomes i64, TDM base types become v4i32, and a TDM descriptor expands
// into four vectors (1:N).
// NOTE(review): the extraction this listing came from dropped the
// declarator line and one lambda-header line below — restore from upstream
// before compiling.
    TypeConverter &typeConverter) {
  typeConverter.addTypeAttributeConversion(
      [](BaseMemRefType type, amdgpu::AddressSpaceAttr as)
          -> TypeConverter::AttributeConversionResult {
        MLIRContext *ctx = as.getContext();
        Type i64 = IntegerType::get(ctx, 64);
        // Address-space numbers match the LLVM AMDGPU conventions for
        // buffer fat pointers / resources.
        switch (as.getValue()) {
        case amdgpu::AddressSpace::FatRawBuffer:
          return IntegerAttr::get(i64, 7);
        case amdgpu::AddressSpace::BufferRsrc:
          return IntegerAttr::get(i64, 8);
        case amdgpu::AddressSpace::FatStructuredBuffer:
          return IntegerAttr::get(i64, 9);
        }
        return TypeConverter::AttributeConversionResult::abort();
      });
  typeConverter.addConversion([&](DsBarrierStateType type) -> Type {
    return IntegerType::get(type.getContext(), 64);
  });
  typeConverter.addConversion([&](TDMBaseType type) -> Type {
    Type i32 = IntegerType::get(type.getContext(), 32);
    return typeConverter.convertType(VectorType::get(4, i32));
  });
  typeConverter.addConversion([&](TDMGatherBaseType type) -> Type {
    Type i32 = IntegerType::get(type.getContext(), 32);
    return typeConverter.convertType(VectorType::get(4, i32));
  });
  // A TDM descriptor is 1:N-converted into its four constituent groups.
  typeConverter.addConversion(
      [&](TDMDescriptorType type,
          SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
        Type i32 = IntegerType::get(type.getContext(), 32);
        Type v4i32 = typeConverter.convertType(VectorType::get(4, i32));
        Type v8i32 = typeConverter.convertType(VectorType::get(8, i32));
        llvm::append_values(result, v4i32, v8i32, v4i32, v4i32);
        return success();
      });

  auto addUnrealizedCast = [](OpBuilder &builder, TypeRange types,
                              ValueRange inputs,
    // Only create unrealized_conversion_cast for TDMDescriptorType.
    // All other types which are not expected, should be
    // materialized by other target materialization functions.
    if (inputs.size() != 1)
      return {};

    if (!isa<TDMDescriptorType>(inputs[0].getType()))
      return {};

    auto cast = UnrealizedConversionCastOp::create(builder, loc, types, inputs);
    return cast.getResults();
  };

  typeConverter.addTargetMaterialization(addUnrealizedCast);
}
3806
// Registers every AMDGPU->ROCDL lowering pattern. Chipset-sensitive
// patterns receive the parsed chipset; the remaining patterns only need the
// type converter.
// NOTE(review): the extraction this listing came from dropped this
// function's declarator line(s) — restore from upstream before compiling.
                                                   Chipset chipset) {
  patterns
      .add<FatRawBufferCastLowering,
           RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
           RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
           RawBufferOpLowering<RawBufferAtomicFaddOp,
                               ROCDL::RawPtrBufferAtomicFaddOp>,
           RawBufferOpLowering<RawBufferAtomicFmaxOp,
                               ROCDL::RawPtrBufferAtomicFmaxOp>,
           RawBufferOpLowering<RawBufferAtomicSmaxOp,
                               ROCDL::RawPtrBufferAtomicSmaxOp>,
           RawBufferOpLowering<RawBufferAtomicUminOp,
                               ROCDL::RawPtrBufferAtomicUminOp>,
           RawBufferOpLowering<RawBufferAtomicCmpswapOp,
                               ROCDL::RawPtrBufferAtomicCmpSwap>,
           AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
           SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
           SparseMFMAOpLowering, WMMAOpLowering, ScaledWMMAOpLowering,
           ExtPackedFp8OpLowering, ScaledExtPackedMatrixOpLowering,
           ScaledExtPackedOpLowering, PackedScaledTruncOpLowering,
           PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
           GatherToLDSOpLowering, TransposeLoadOpLowering,
           AMDGPUPermlaneLowering, AMDGPUMakeDmaBaseLowering<MakeDmaBaseOp>,
           AMDGPUMakeDmaBaseLowering<MakeGatherDmaBaseOp>,
           AMDGPULowerDescriptor<MakeDmaDescriptorOp>,
           AMDGPULowerDescriptor<MakeGatherDmaDescriptorOp>,
           AMDGPUTensorLoadStoreOpLowering<TensorLoadToLDSOp,
                                           ROCDL::TensorLoadToLDSOp>,
           AMDGPUTensorLoadStoreOpLowering<TensorStoreFromLDSOp,
                                           ROCDL::TensorStoreFromLDSOp>,
           DsBarrierInitOpLowering, DsBarrierPollStateOpLowering,
           DsAsyncBarrierArriveOpLowering, DsBarrierArriveOpLowering>(converter,
                                                                      chipset);
  // These patterns do not depend on the chipset.
  patterns.add<AMDGPUSwizzleBitModeLowering, DsBarrierStatePhaseOpLowering,
               DsBarrierStatePendingCountOpLowering,
               DsBarrierStateInitCountOpLowering,
               DsBarrierStatePhaseParityLowering>(converter);
}
static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type)
Return true if type is the E4M3FN variant of an 8-bit float that is supported by the _fp8 instruction...
constexpr Chipset kGfx942
static std::optional< StringRef > wmmaOpToIntrinsicRDNA(Type elemSourceType, Type elemBSourceType, Type elemDestType, uint32_t k, bool isRDNA3)
Returns the rocdl intrinsic corresponding to a WMMA operation wmma for RDNA3/4 architectures.
static std::optional< std::tuple< StringRef, uint32_t, uint32_t > > mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m, uint32_t n, uint32_t k, uint32_t b, Chipset chipset)
If there is a scaled MFMA instruction for the input element types aType and bType,...
static std::optional< StringRef > mfmaOpToIntrinsic(MFMAOp mfma, Chipset chipset)
Return the rocdl intrinsic corresponding to a MFMA operation mfma if one exists.
constexpr Chipset kGfx908
static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput, Value mlirInput, SmallVectorImpl< Value > &operands, SmallVectorImpl< NamedAttribute > &attrs, StringRef attrName)
Push an input operand.
constexpr Chipset kGfx1250
static Value castScaleOperand(ConversionPatternRewriter &rewriter, Location loc, Value input)
Converts the scaled MFMA/WMMA operands, scalesA and scalesB, from MLIR AMDGPU dialect convention to R...
constexpr Chipset kGfx90a
static std::optional< StringRef > getScaledWmmaIntrinsicName(int64_t m, int64_t n, int64_t k, bool isScale16)
Determines the ROCDL intrinsic name for scaled WMMA based on dimensions and scale block size (16 or 3...
static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, Value output, int32_t subwordOffset, bool clamp, SmallVectorImpl< Value > &operands, SmallVectorImpl< NamedAttribute > &attrs)
Push the output operand.
static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type)
Return true if type is the E5M2 variant of an 8-bit float that is supported by the _bf8 instructions ...
static std::optional< StringRef > wmmaOpToIntrinsic(WMMAOp wmma, Chipset chipset)
Returns the rocdl intrinsic corresponding to a WMMA operation wmma if one exists.
static std::optional< StringRef > smfmacOpToIntrinsic(SparseMFMAOp op, Chipset chipset)
Returns the rocdl intrinsic corresponding to a SparseMFMA (smfmac) operation if one exists.
static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc, Value basePointer, Value numRecords, bool boundsCheck, amdgpu::Chipset chipset, Value cacheSwizzleStride=nullptr, unsigned addressSpace=8)
static Value createI64Constant(ConversionPatternRewriter &rewriter, Location loc, int64_t value)
static Value convertSparseMFMAVectorOperand(ConversionPatternRewriter &rewriter, Location loc, Value input, bool allowBf16=true)
Converts sparse MFMA (smfmac) operands to the expected ROCDL types.
static std::optional< StringRef > wmmaOpToIntrinsicGfx1250(Type elemSourceType, Type elemBSourceType, Type elemDestType, uint32_t k)
Return the rocdl intrinsic corresponding to a WMMA operation wmma for the gfx1250 architecture.
static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc, MemRefType memrefType, MemRefDescriptor &memrefDescriptor, ArrayRef< int64_t > strides, int64_t elementByteWidth, amdgpu::Chipset chipset, bool boundsCheck)
Compute the contents of the num_records field for a given memref descriptor - that is,...
static Value packSmallFloatVectorOperand(ConversionPatternRewriter &rewriter, Location loc, Value input, bool allowBf16=true)
Pack small float vector operands (fp4/fp6/fp8/bf16) into the format expected by scaled matrix multipl...
static std::optional< uint32_t > getWmmaScaleFormat(Type elemType)
Maps f8 scale element types to WMMA scale format codes.
static Value getLinearIndexI32(ConversionPatternRewriter &rewriter, Location loc, MemRefDescriptor &memRefDescriptor, ValueRange indices, ArrayRef< int64_t > strides)
Returns the linear index used to access an element in the memref.
static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter, Location loc, Value val)
Convert an unsigned number val to i32.
static Value createI32Constant(ConversionPatternRewriter &rewriter, Location loc, int32_t value)
static std::optional< uint32_t > smallFloatTypeToFormatCode(Type mlirElemType)
static Value convertUnsignedToI64(ConversionPatternRewriter &rewriter, Location loc, Value val)
Convert an unsigned number val to i64.
constexpr Chipset kGfx950
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
b getContext())
auto load
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static constexpr unsigned kSizePosInMemRefDescriptor
static constexpr unsigned kStridePosInMemRefDescriptor
static constexpr unsigned kOffsetPosInMemRefDescriptor
static constexpr unsigned kAllocatedPtrPosInMemRefDescriptor
static constexpr unsigned kAlignedPtrPosInMemRefDescriptor
static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound)
This class provides a shared interface for ranked and unranked memref types.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition Pattern.h:216
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition Pattern.h:222
typename SourceOp::template GenericAdaptor< ArrayRef< ValueRange > > OneToNOpAdaptor
Definition Pattern.h:219
typename SourceOp::Adaptor OpAdaptor
Definition Pattern.h:218
Value getStridedElementPtr(ConversionPatternRewriter &rewriter, Location loc, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none) const
Convenience wrapper for the corresponding helper utility.
Definition Pattern.cpp:64
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Value stride(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
Value size(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
This class helps build Operations.
Definition Builders.h:209
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
result_range getResults()
Definition Operation.h:415
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isF64() const
Definition Types.cpp:41
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition Types.cpp:78
bool isF32() const
Definition Types.cpp:40
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition Types.cpp:90
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
bool isF16() const
Definition Types.cpp:38
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
bool isBF16() const
Definition Types.cpp:37
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDes...
Definition Pattern.cpp:471
bool hasOcpFp8(const Chipset &chipset)
Definition Chipset.h:52
void populateCommonGPUTypeAndAttributeConversions(TypeConverter &typeConverter)
Remap common GPU memory spaces (Workgroup, Private, etc) to LLVM address spaces.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:305
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition Matchers.h:442
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void populateGpuMemorySpaceAttributeConversions(TypeConverter &typeConverter, const MemorySpaceMapping &mapping)
Populates memory space attribute conversion rules for lowering gpu.address_space to integer values.
void populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, amdgpu::Chipset chipset)
Note: This function will also add conversions for the AMDGPU-specific address spaces and types,...
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:136
void populateAMDGPUTypeAndAttributeConversions(TypeConverter &typeConverter)
Remap AMDGPU memory spaces to LLVM address spaces by mapping amdgpu::AddressSpace::fat_raw_buffer to ...
Represents the amdgpu gfx chipset version, e.g., gfx90a, gfx942, gfx1103.
Definition Chipset.h:22
unsigned majorVersion
Definition Chipset.h:23
unsigned minorVersion
Definition Chipset.h:24
static FailureOr< Chipset > parse(StringRef name)
Parses the chipset version string and returns the chipset on success, and failure otherwise.
Definition Chipset.cpp:14