//===- AMDGPUToROCDL.cpp - AMDGPU to ROCDL dialect conversion -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/Attributes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_CONVERTAMDGPUTOROCDLPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::amdgpu;

// Define commonly used chipset versions for convenience.
constexpr Chipset kGfx908 = Chipset(9, 0, 8);
constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
constexpr Chipset kGfx942 = Chipset(9, 4, 2);
constexpr Chipset kGfx950 = Chipset(9, 5, 0);
constexpr Chipset kGfx1250 = Chipset(12, 5, 0);

/// Convert an unsigned number `val` to i32.
static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
                                  Location loc, Value val) {
  IntegerType i32 = rewriter.getI32Type();
  // The cast enforces that `val` has an integer type.
  auto valTy = cast<IntegerType>(val.getType());
  if (i32 == valTy)
    return val;
  return valTy.getWidth() > 32
             ? Value(LLVM::TruncOp::create(rewriter, loc, i32, val))
             : Value(LLVM::ZExtOp::create(rewriter, loc, i32, val));
}

static Value createI32Constant(ConversionPatternRewriter &rewriter,
                               Location loc, int32_t value) {
  return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), value);
}

/// Convert an unsigned number `val` to i64.
static Value convertUnsignedToI64(ConversionPatternRewriter &rewriter,
                                  Location loc, Value val) {
  IntegerType i64 = rewriter.getI64Type();
  // The cast enforces that `val` has an integer type.
  auto valTy = cast<IntegerType>(val.getType());
  if (i64 == valTy)
    return val;
  return valTy.getWidth() > 64
             ? Value(LLVM::TruncOp::create(rewriter, loc, i64, val))
             : Value(LLVM::ZExtOp::create(rewriter, loc, i64, val));
}

static Value createI64Constant(ConversionPatternRewriter &rewriter,
                               Location loc, int64_t value) {
  return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), value);
}

/// Returns the linear index used to access an element in the memref.
static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
                               Location loc, MemRefDescriptor &memRefDescriptor,
                               ValueRange indices, ArrayRef<int64_t> strides) {
  IntegerType i32 = rewriter.getI32Type();
  Value index;
  for (auto [i, increment, stride] : llvm::enumerate(indices, strides)) {
    if (stride != 1) { // Skip the multiplication if the stride is 1.
      Value strideValue =
          ShapedType::isDynamic(stride)
              ? convertUnsignedToI32(rewriter, loc,
                                     memRefDescriptor.stride(rewriter, loc, i))
              : LLVM::ConstantOp::create(rewriter, loc, i32, stride);
      increment = LLVM::MulOp::create(rewriter, loc, increment, strideValue);
    }
    index = index ? LLVM::AddOp::create(rewriter, loc, index, increment)
                  : increment;
  }
  return index ? index : createI32Constant(rewriter, loc, 0);
}
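
// For illustration (a hypothetical case, not taken from the surrounding code):
// for a memref<?x4xf32> with strides [4, 1] and indices (%i, %j), the loop
// above computes %i * 4 + %j as an i32, multiplying only where the stride is
// not 1.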

/// Compute the contents of the `num_records` field for a given memref
/// descriptor - that is, the byte offset one element past the largest valid
/// index into the memref.
static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc,
                           MemRefType memrefType,
                           MemRefDescriptor &memrefDescriptor,
                           ArrayRef<int64_t> strides,
                           int64_t elementByteWidth) {
  if (memrefType.hasStaticShape() &&
      !llvm::any_of(strides, ShapedType::isDynamic)) {
    int64_t size = memrefType.getRank() == 0 ? 1 : 0;
    ArrayRef<int64_t> shape = memrefType.getShape();
    for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
      size = std::max(shape[i] * strides[i], size);
    size = size * elementByteWidth;
    return createI64Constant(rewriter, loc, size);
  }
  Value maxIndex;
  for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
    Value size = memrefDescriptor.size(rewriter, loc, i);
    Value stride = memrefDescriptor.stride(rewriter, loc, i);
    Value maxThisDim = LLVM::MulOp::create(rewriter, loc, size, stride);
    maxIndex = maxIndex
                   ? LLVM::UMaxOp::create(rewriter, loc, maxIndex, maxThisDim)
                   : maxThisDim;
  }
  Value maxIndexI64 = convertUnsignedToI64(rewriter, loc, maxIndex);
  Value byteWidthConst = createI64Constant(rewriter, loc, elementByteWidth);
  return LLVM::MulOp::create(rewriter, loc, maxIndexI64, byteWidthConst);
}
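
// Worked example (hypothetical shapes, not from the code above): for a static
// memref<8x4xf32> with strides [4, 1] and elementByteWidth = 4, the static
// path computes max(8 * 4, 4 * 1) * 4 = 128 bytes for num_records.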

static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc,
                            Value basePointer, Value numRecords,
                            bool boundsCheck, amdgpu::Chipset chipset,
                            Value cacheSwizzleStride = nullptr,
                            unsigned addressSpace = 8) {
  // The stride value is generally 0. However, on MI-300 (gfx942) and onward,
  // a cache swizzling mode can be enabled by setting bit 14 of the stride
  // field and setting that stride to a cache stride.
  Type i16 = rewriter.getI16Type();
  Value stride;
  if (chipset.majorVersion == 9 && chipset >= kGfx942 && cacheSwizzleStride) {
    Value cacheStrideZext =
        LLVM::ZExtOp::create(rewriter, loc, i16, cacheSwizzleStride);
    Value swizzleBit = LLVM::ConstantOp::create(
        rewriter, loc, i16, rewriter.getI16IntegerAttr(1 << 14));
    stride = LLVM::OrOp::create(rewriter, loc, cacheStrideZext, swizzleBit,
                                /*isDisjoint=*/true);
  } else {
    stride = LLVM::ConstantOp::create(rewriter, loc, i16,
                                      rewriter.getI16IntegerAttr(0));
  }
  // Flag word:
  // bits 0-11: dst sel, ignored by these intrinsics
  // bits 12-14: data format (ignored, must be nonzero, 7=float)
  // bits 15-18: num format (ignored, must be nonzero, 4=32bit)
  // bit 19: In nested heap (0 here)
  // bit 20: Behavior on unmap (0 means "return 0 / ignore")
  // bits 21-22: Index stride for swizzles (N/A)
  // bit 23: Add thread ID (0)
  // bit 24: Reserved to 1 (RDNA) or 0 (CDNA)
  // bits 25-26: Reserved (0)
  // bit 27: Buffer is non-volatile (CDNA only)
  // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 =
  //   none, 3 = either swizzles or testing against offset field) RDNA only
  // bits 30-31: Type (must be 0)
  uint32_t flags = (7 << 12) | (4 << 15);
  if (chipset.majorVersion >= 10) {
    flags |= (1 << 24);
    uint32_t oob = boundsCheck ? 3 : 2;
    flags |= (oob << 28);
  }
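  // For illustration: on gfx11 with boundsCheck = true, the lines above yield
  // flags = (7 << 12) | (4 << 15) | (1 << 24) | (3u << 28) = 0x31027000.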
  Value flagsConst = createI32Constant(rewriter, loc, flags);
  Type rsrcType =
      LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
  Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
      loc, rsrcType, basePointer, stride, numRecords, flagsConst);
  return resource;
}

namespace {
struct FatRawBufferCastLowering
    : public ConvertOpToLLVMPattern<FatRawBufferCastOp> {
  FatRawBufferCastLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<FatRawBufferCastOp>(converter),
        chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value memRef = adaptor.getSource();
    Value unconvertedMemref = op.getSource();
    MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());
    MemRefDescriptor descriptor(memRef);

    DataLayout dataLayout = DataLayout::closest(op);
    int64_t elementByteWidth =
        dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;

    int64_t unusedOffset = 0;
    SmallVector<int64_t, 5> strideVals;
    if (failed(memrefType.getStridesAndOffset(strideVals, unusedOffset)))
      return op.emitOpError("can't lower non-stride-offset memrefs");

    Value numRecords = adaptor.getValidBytes();
    if (!numRecords)
      numRecords = getNumRecords(rewriter, loc, memrefType, descriptor,
                                 strideVals, elementByteWidth);

    Value basePointer =
        adaptor.getResetOffset()
            ? descriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
                                   memrefType)
            : descriptor.alignedPtr(rewriter, loc);

    Value offset = adaptor.getResetOffset()
                       ? LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
                                                  rewriter.getIndexAttr(0))
                       : descriptor.offset(rewriter, loc);

    bool hasSizes = memrefType.getRank() > 0;
    // No need to unpack() and pack() all the individual sizes and strides;
    // just extract the whole arrays.
    Value sizes = hasSizes
                      ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
                                                     kSizePosInMemRefDescriptor)
                      : Value{};
    Value strides =
        hasSizes ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
                                                kStridePosInMemRefDescriptor)
                 : Value{};

    Value fatPtr = makeBufferRsrc(
        rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(),
        chipset, adaptor.getCacheSwizzleStride(), /*addressSpace=*/7);

    Value result = MemRefDescriptor::poison(
        rewriter, loc,
        getTypeConverter()->convertType(op.getResult().getType()));
    SmallVector<int64_t> pos{kAllocatedPtrPosInMemRefDescriptor};
    result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr, pos);
    result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr,
                                         kAlignedPtrPosInMemRefDescriptor);
    result = LLVM::InsertValueOp::create(rewriter, loc, result, offset,
                                         kOffsetPosInMemRefDescriptor);
    if (hasSizes) {
      result = LLVM::InsertValueOp::create(rewriter, loc, result, sizes,
                                           kSizePosInMemRefDescriptor);
      result = LLVM::InsertValueOp::create(rewriter, loc, result, strides,
                                           kStridePosInMemRefDescriptor);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// Define lowering patterns for raw buffer ops.
template <typename GpuOp, typename Intrinsic>
struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
  RawBufferOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}

  Chipset chipset;
  static constexpr uint32_t maxVectorOpWidth = 128;

  LogicalResult
  matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = gpuOp.getLoc();
    Value memref = adaptor.getMemref();
    Value unconvertedMemref = gpuOp.getMemref();
    MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());

    if (chipset.majorVersion < 9)
      return gpuOp.emitOpError("raw buffer ops require GCN or higher");

    Value storeData = adaptor.getODSOperands(0)[0];
    if (storeData == memref) // no write component to this op
      storeData = Value();
    Type wantedDataType;
    if (storeData)
      wantedDataType = storeData.getType();
    else
      wantedDataType = gpuOp.getODSResults(0)[0].getType();

    Value atomicCmpData = Value();
    // Operand index 1 of a load is its indices; trying to read them as data
    // can crash.
    if (storeData) {
      Value maybeCmpData = adaptor.getODSOperands(1)[0];
      if (maybeCmpData != memref)
        atomicCmpData = maybeCmpData;
    }

    Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);

    Type i32 = rewriter.getI32Type();

    // Get the type size in bytes.
    DataLayout dataLayout = DataLayout::closest(gpuOp);
    int64_t elementByteWidth =
        dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;
    Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);

    // If we want to load a vector<NxT> with total size <= 32 bits, use a
    // scalar load and bitcast it. Similarly, if bitsize(T) < 32 and the total
    // load size is >= 32, use a vector load of (N * bitsize(T)) / 32 x i32 and
    // bitcast. Also, the CAS intrinsic requires integer operands, so bitcast
    // any floats to integers.
    Type llvmBufferValType = llvmWantedDataType;
    if (atomicCmpData) {
      if (auto floatType = dyn_cast<FloatType>(wantedDataType))
        llvmBufferValType = this->getTypeConverter()->convertType(
            rewriter.getIntegerType(floatType.getWidth()));
    }
    if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
      uint32_t vecLen = dataVector.getNumElements();
      uint32_t elemBits =
          dataLayout.getTypeSizeInBits(dataVector.getElementType());
      uint32_t totalBits = elemBits * vecLen;
      bool usePackedFp16 =
          isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
      if (totalBits > maxVectorOpWidth)
        return gpuOp.emitOpError(
            "total width of loads or stores must be no more than " +
            Twine(maxVectorOpWidth) + " bits, but the op calls for " +
            Twine(totalBits) +
            " bits. This should've been caught in validation");
      if (!usePackedFp16 && elemBits < 32) {
        if (totalBits > 32) {
          if (totalBits % 32 != 0)
            return gpuOp.emitOpError("load or store of more than 32 bits that "
                                     "doesn't fit into words; can't happen");
          llvmBufferValType = this->typeConverter->convertType(
              VectorType::get(totalBits / 32, i32));
        } else {
          llvmBufferValType = this->typeConverter->convertType(
              rewriter.getIntegerType(totalBits));
        }
      }
    }
    if (auto vecType = dyn_cast<VectorType>(llvmBufferValType)) {
      // Buffer intrinsics don't support 1-element vectors; cast them to
      // scalars.
      if (vecType.getNumElements() == 1)
        llvmBufferValType = vecType.getElementType();
    }

    SmallVector<Value, 6> args;
    if (storeData) {
      if (llvmBufferValType != llvmWantedDataType) {
        Value castForStore = LLVM::BitcastOp::create(
            rewriter, loc, llvmBufferValType, storeData);
        args.push_back(castForStore);
      } else {
        args.push_back(storeData);
      }
    }

    if (atomicCmpData) {
      if (llvmBufferValType != llvmWantedDataType) {
        Value castForCmp = LLVM::BitcastOp::create(
            rewriter, loc, llvmBufferValType, atomicCmpData);
        args.push_back(castForCmp);
      } else {
        args.push_back(atomicCmpData);
      }
    }

    // Construct the buffer descriptor from the memref and attributes.
    int64_t offset = 0;
    SmallVector<int64_t, 5> strides;
    if (failed(memrefType.getStridesAndOffset(strides, offset)))
      return gpuOp.emitOpError("can't lower non-stride-offset memrefs");

    MemRefDescriptor memrefDescriptor(memref);

    Value ptr = memrefDescriptor.bufferPtr(
        rewriter, loc, *this->getTypeConverter(), memrefType);
    Value numRecords = getNumRecords(
        rewriter, loc, memrefType, memrefDescriptor, strides, elementByteWidth);
    Value resource = makeBufferRsrc(rewriter, loc, ptr, numRecords,
                                    adaptor.getBoundsCheck(), chipset);
    args.push_back(resource);

    // Indexing (voffset).
    Value voffset = getLinearIndexI32(rewriter, loc, memrefDescriptor,
                                      adaptor.getIndices(), strides);
    if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
        indexOffset && *indexOffset > 0) {
      Value extraOffsetConst = createI32Constant(rewriter, loc, *indexOffset);
      voffset = voffset ? LLVM::AddOp::create(rewriter, loc, voffset,
                                              extraOffsetConst)
                        : extraOffsetConst;
    }
    voffset = LLVM::MulOp::create(rewriter, loc, voffset, byteWidthConst);
    args.push_back(voffset);

    // SGPR offset.
    Value sgprOffset = adaptor.getSgprOffset();
    if (!sgprOffset)
      sgprOffset = createI32Constant(rewriter, loc, 0);
    sgprOffset = LLVM::MulOp::create(rewriter, loc, sgprOffset, byteWidthConst);
    args.push_back(sgprOffset);

    // bit 0: GLC = 0 (atomics drop value, less coherency)
    // bits 1-2: SLC, DLC = 0 (similarly)
    // bit 3: swizzled (0 for raw)
    args.push_back(createI32Constant(rewriter, loc, 0));

    llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
                                           llvmBufferValType);
    Operation *lowered = Intrinsic::create(rewriter, loc, resultTypes, args,
                                           ArrayRef<NamedAttribute>());
    if (lowered->getNumResults() == 1) {
      Value replacement = lowered->getResult(0);
      if (llvmBufferValType != llvmWantedDataType) {
        replacement = LLVM::BitcastOp::create(rewriter, loc, llvmWantedDataType,
                                              replacement);
      }
      rewriter.replaceOp(gpuOp, replacement);
    } else {
      rewriter.eraseOp(gpuOp);
    }
    return success();
  }
};

// TODO: the AMDGPU backend already has all this bitpacking logic; we should
// move it to some common place.
/// Vmcnt, Expcnt and Lgkmcnt are decoded as follows:
///     Vmcnt = Waitcnt[3:0]        (pre-gfx9)
///     Vmcnt = Waitcnt[15:14,3:0]  (gfx9,10)
///     Vmcnt = Waitcnt[15:10]      (gfx11)
///     Expcnt = Waitcnt[6:4]       (pre-gfx11)
///     Expcnt = Waitcnt[2:0]       (gfx11)
///     Lgkmcnt = Waitcnt[11:8]     (pre-gfx10)
///     Lgkmcnt = Waitcnt[13:8]     (gfx10)
///     Lgkmcnt = Waitcnt[9:4]      (gfx11)
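///
/// Worked example (derived from the gfx9 case below): encodeWaitcnt(gfx9,
/// vmcnt=18, expcnt=7, lgkmcnt=15) splits vmcnt into low bits 18 & 0xF = 2 and
/// high bits (18 >> 4) << 14 = 0x4000, then ORs in (7 << 4) | (15 << 8),
/// giving 0x4000 | 0x2 | 0x70 | 0xF00 = 0x4F72.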
static FailureOr<unsigned> encodeWaitcnt(Chipset chipset, unsigned vmcnt,
                                         unsigned expcnt, unsigned lgkmcnt) {
  if (chipset.majorVersion < 9) {
    vmcnt = std::min(15u, vmcnt);
    expcnt = std::min(7u, expcnt);
    lgkmcnt = std::min(15u, lgkmcnt);
    return vmcnt | (expcnt << 4) | (lgkmcnt << 8);
  }
  if (chipset.majorVersion == 9) {
    vmcnt = std::min(63u, vmcnt);
    expcnt = std::min(7u, expcnt);
    lgkmcnt = std::min(15u, lgkmcnt);
    unsigned lowBits = vmcnt & 0xF;
    unsigned highBits = (vmcnt >> 4) << 14;
    unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
    return lowBits | highBits | otherCnts;
  }
  if (chipset.majorVersion == 10) {
    vmcnt = std::min(63u, vmcnt);
    expcnt = std::min(7u, expcnt);
    lgkmcnt = std::min(63u, lgkmcnt);
    unsigned lowBits = vmcnt & 0xF;
    unsigned highBits = (vmcnt >> 4) << 14;
    unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
    return lowBits | highBits | otherCnts;
  }
  if (chipset.majorVersion == 11) {
    vmcnt = std::min(63u, vmcnt);
    expcnt = std::min(7u, expcnt);
    lgkmcnt = std::min(63u, lgkmcnt);
    return (vmcnt << 10) | expcnt | (lgkmcnt << 4);
  }
  return failure();
}

struct MemoryCounterWaitOpLowering
    : public ConvertOpToLLVMPattern<MemoryCounterWaitOp> {
  MemoryCounterWaitOpLowering(const LLVMTypeConverter &converter,
                              Chipset chipset)
      : ConvertOpToLLVMPattern<MemoryCounterWaitOp>(converter),
        chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(MemoryCounterWaitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset.majorVersion >= 12) {
      Location loc = op.getLoc();
      if (std::optional<int> ds = adaptor.getDs())
        ROCDL::WaitDscntOp::create(rewriter, loc, *ds);

      if (std::optional<int> load = adaptor.getLoad())
        ROCDL::WaitLoadcntOp::create(rewriter, loc, *load);

      if (std::optional<int> store = adaptor.getStore())
        ROCDL::WaitStorecntOp::create(rewriter, loc, *store);

      if (std::optional<int> exp = adaptor.getExp())
        ROCDL::WaitExpcntOp::create(rewriter, loc, *exp);

      if (std::optional<int> tensor = adaptor.getTensor())
        ROCDL::WaitTensorcntOp::create(rewriter, loc, *tensor);

      rewriter.eraseOp(op);
      return success();
    }

    if (adaptor.getTensor())
      return op.emitOpError("unsupported chipset");

    auto getVal = [](Attribute attr) -> unsigned {
      if (attr)
        return cast<IntegerAttr>(attr).getInt();

      // This value will be clamped to the maximum value for the chipset.
      return 1024;
    };
    unsigned ds = getVal(adaptor.getDsAttr());
    unsigned exp = getVal(adaptor.getExpAttr());

    unsigned vmcnt = 1024;
    Attribute load = adaptor.getLoadAttr();
    Attribute store = adaptor.getStoreAttr();
    if (load && store) {
      vmcnt = getVal(load) + getVal(store);
    } else if (load) {
      vmcnt = getVal(load);
    } else if (store) {
      vmcnt = getVal(store);
    }

    FailureOr<unsigned> waitcnt = encodeWaitcnt(chipset, vmcnt, exp, ds);
    if (failed(waitcnt))
      return op.emitOpError("unsupported chipset");

    rewriter.replaceOpWithNewOp<ROCDL::SWaitcntOp>(op, *waitcnt);
    return success();
  }
};
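
// Sketch of the resulting IR (syntax approximate, for illustration only):
//   amdgpu.memory_counter_wait load(0)
// becomes, on gfx9-11, a single encoded wait
//   rocdl.s.waitcnt <encoded immediate>
// and, on gfx12+, one dedicated wait per specified counter, e.g.
//   rocdl.s.wait.loadcnt 0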

struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
  LDSBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // This ensures that waits on global memory aren't introduced on
    // chips that don't have the BackOffBarrier feature enabled in LLVM.
    bool requiresInlineAsm = chipset < kGfx90a;

    Attribute mmra =
        rewriter.getAttr<LLVM::MMRATagAttr>("amdgpu-synchronize-as", "local");
    // Note: while there *is* a workgroup-one-as scope, this, when combined with
    // the MMRA, will lead to the fence having no effect. This is because the
    // codepaths for an atomic load or store will observe that a
    // one-address-space atomic to LDS requires no synchronization because
    // operations on LDS are totally ordered with respect to each other, and so
    // will not emit the correct waitcnt operations that these fences are
    // intended to produce. Therefore, we use a broader type of fence and rely
    // on the MMRA to relax it to the semantics we want.
    StringRef scope = "workgroup";

    auto relFence = LLVM::FenceOp::create(rewriter, loc,
                                          LLVM::AtomicOrdering::release, scope);
    relFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra);
    if (requiresInlineAsm) {
      auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
                                                      LLVM::AsmDialect::AD_ATT);
      const char *asmStr = ";;;WARNING: BREAKS DEBUG WATCHES\ns_barrier";
      const char *constraints = "";
      LLVM::InlineAsmOp::create(
          rewriter, loc,
          /*resultTypes=*/TypeRange(), /*operands=*/ValueRange(),
          /*asm_string=*/asmStr, constraints, /*has_side_effects=*/true,
          /*is_align_stack=*/false, LLVM::TailCallKind::None,
          /*asm_dialect=*/asmDialectAttr,
          /*operand_attrs=*/ArrayAttr());
    } else if (chipset.majorVersion < 12) {
      ROCDL::SBarrierOp::create(rewriter, loc);
    } else {
      ROCDL::BarrierSignalOp::create(rewriter, loc, -1);
      ROCDL::BarrierWaitOp::create(rewriter, loc, -1);
    }

    auto acqFence = LLVM::FenceOp::create(rewriter, loc,
                                          LLVM::AtomicOrdering::acquire, scope);
    acqFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra);
    rewriter.replaceOp(op, acqFence);
    return success();
  }
};
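
// The emitted sequence is, schematically (illustrative, not verbatim IR):
//   llvm.fence release, "workgroup" {mmra = "amdgpu-synchronize-as":"local"}
//   rocdl.s.barrier            // or barrier.signal/barrier.wait on gfx12+,
//                              // or inline-asm s_barrier pre-gfx90a
//   llvm.fence acquire, "workgroup" {mmra = "amdgpu-synchronize-as":"local"}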

struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
  SchedBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<SchedBarrierOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<ROCDL::SchedBarrier>(op,
                                                     (uint32_t)op.getOpts());
    return success();
  }
};

} // namespace

/// Converts an MFMA vector operand from the MLIR AMDGPU dialect convention to
/// the ROCDL and LLVM AMDGPU intrinsics convention.
///
/// Specifically:
/// 1. If the element type is bfloat16, bitcast it to i16 unless the ROCDL
/// intrinsic allows bf16; newer MFMAs accept bf16 operands (see
/// IntrinsicsAMDGPU.td for reference).
/// 2. If instead we have a more than 64-bit quantity, use a <N / 4 x i32>
/// instead, which is what the f8f6f4 intrinsics use.
/// 3. If `input` is a vector of N <= 8 bytes, bitcast it to a (N * 8)-bit
/// integer.
///
/// Note that the type of `input` has already been LLVM type converted:
/// therefore 8-bit and smaller floats are represented as their corresponding
/// `iN` integers.
static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
                                      Location loc, Value input,
                                      bool allowBf16 = true) {
  Type inputType = input.getType();
  if (auto vectorType = dyn_cast<VectorType>(inputType)) {
    if (vectorType.getElementType().isBF16() && !allowBf16)
      return LLVM::BitcastOp::create(
          rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
    if (vectorType.getElementType().isInteger(8) &&
        vectorType.getNumElements() <= 8)
      return LLVM::BitcastOp::create(
          rewriter, loc,
          rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
    if (isa<IntegerType>(vectorType.getElementType()) &&
        vectorType.getElementTypeBitWidth() <= 8) {
      int64_t numWords = llvm::divideCeil(
          vectorType.getNumElements() * vectorType.getElementTypeBitWidth(),
          32);
      return LLVM::BitcastOp::create(
          rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()),
          input);
    }
  }
  return input;
}
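
// Examples of the rules above (illustrative): a vector<4xbf16> operand with
// allowBf16 = false becomes vector<4xi16>; a vector<8xi8> (e.g. an
// LLVM-converted 8-byte fp8 vector) becomes i64; and a vector<16xi8> becomes
// vector<4xi32>, the form the f8f6f4 intrinsics expect.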

/// Converts the scaled MFMA operands, `scalesA` and `scalesB`, from the MLIR
/// AMDGPU dialect convention to the ROCDL and LLVM AMDGPU intrinsics
/// convention.
///
/// Specifically:
/// 1. If `input` is an i8 value, zero extend it to i32.
/// 2. If `input` is a vector of length 4 and type i8, bitcast it to i32.
///
/// Note that the type of `input` has already been LLVM type converted:
/// therefore 8-bit and smaller floats are represented as their corresponding
/// `iN` integers.
static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter,
                                  Location loc, Value input) {
  Type inputType = input.getType();
  Type outputType = rewriter.getI32Type();
  if (auto intType = dyn_cast<IntegerType>(inputType))
    return LLVM::ZExtOp::create(rewriter, loc, outputType, input);
  return LLVM::BitcastOp::create(rewriter, loc, outputType, input);
}
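
// For illustration: a single i8 scale is zero-extended to i32, while a
// vector<4xi8> of scales is reinterpreted in place as one i32 word.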

/// Push an input operand. If it is a float type, there is nothing to do. If it
/// is an integer type, then we need to also push its signedness (1 for signed,
/// 0 for unsigned) and we need to pack the input 16xi8 vector into a 4xi32
/// vector (or the 8xi8 vector into a 2xi32 one for gfx12+).
/// We also need to convert bfloat inputs to i16 to account for the bfloat
/// intrinsics having been defined before the AMD backend supported bfloat. We
/// similarly need to pack 8-bit float types into integers as if they were i8
/// (which they are for the backend's purposes).
static void wmmaPushInputOperand(
    ConversionPatternRewriter &rewriter, Location loc,
    const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput,
    Value mlirInput, SmallVectorImpl<Value> &operands,
    SmallVectorImpl<NamedAttribute> &attrs, StringRef attrName) {
  Type inputType = llvmInput.getType();
  auto vectorType = dyn_cast<VectorType>(inputType);
  if (!vectorType) {
    operands.push_back(llvmInput);
    return;
  }
  Type elemType = vectorType.getElementType();
  if (elemType.getIntOrFloatBitWidth() > 8) {
    operands.push_back(llvmInput);
    return;
  }

  // We need to check the type of the input before conversion to properly test
  // for int8. This is because, in LLVM, the fp8 type is converted to int8, so
  // the fp8/int8 distinction is lost during the conversion process.
  auto mlirInputType = cast<VectorType>(mlirInput.getType());
  bool isInputInteger = mlirInputType.getElementType().isInteger();
  if (isInputInteger) {
    // If the element type is explicitly 8-bit signed or unsigned, ignore the
    // isUnsigned flag.
    bool localIsUnsigned = isUnsigned;
    if (elemType.isUnsignedInteger()) {
      localIsUnsigned = true;
    } else if (elemType.isSignedInteger()) {
      localIsUnsigned = false;
    }
    attrs.push_back(
        NamedAttribute(attrName, rewriter.getBoolAttr(!localIsUnsigned)));
  }

  int64_t numBits =
      vectorType.getNumElements() * elemType.getIntOrFloatBitWidth();
  Type i32 = rewriter.getI32Type();
  Type intrinsicInType = numBits <= 32
                             ? (Type)rewriter.getIntegerType(numBits)
                             : (Type)VectorType::get(numBits / 32, i32);
  auto llvmIntrinsicInType = typeConverter->convertType(intrinsicInType);
  Value castInput = rewriter.createOrFold<LLVM::BitcastOp>(
      loc, llvmIntrinsicInType, llvmInput);
  // The wave64-mode 16x16x16 intrinsics that take 4-bit integers only need
  // (256 / 64) * 4 = 16 bits of input (on gfx12+) but take i32 arguments.
  // Add in the zeros here.
  if (numBits < 32)
    castInput = LLVM::ZExtOp::create(rewriter, loc, i32, castInput);
  operands.push_back(castInput);
}
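
// For illustration: a signed vector<16xi8> input is pushed as a vector<4xi32>
// together with a "signA"/"signB" attribute set to true, while a vector<4xi4>
// input packs into an i16 and is then zero-extended to the i32 the intrinsic
// expects.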

/// Push the output operand. In many cases this is only pushing the output onto
/// the operand list. But when we have f16 -> f16 or bf16 -> bf16 intrinsics,
/// since the same number of VGPRs is used, we need to decide whether to store
/// the result in the upper or lower 16 bits of the VGPRs. To store the result
/// in the lower 16 bits, set subwordOffset to 1; otherwise the result is
/// stored in the upper part. The subwordOffset must not be set for gfx12, as
/// the instructions have been changed to return fewer registers instead.
static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
                                  Location loc,
                                  const TypeConverter *typeConverter,
                                  Value output, int32_t subwordOffset,
                                  bool clamp, SmallVectorImpl<Value> &operands,
                                  SmallVectorImpl<NamedAttribute> &attrs) {
  Type inputType = output.getType();
  auto vectorType = dyn_cast<VectorType>(inputType);
  Type elemType = vectorType.getElementType();
  operands.push_back(output);
  if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) {
    attrs.push_back(
        NamedAttribute("opsel", rewriter.getBoolAttr(subwordOffset)));
  } else if (elemType.isInteger(32)) {
    attrs.push_back(NamedAttribute("clamp", rewriter.getBoolAttr(clamp)));
  }
}
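
// For illustration: an f16 accumulator with subwordOffset = 1 adds an
// "opsel" = true attribute to the intrinsic call, while an i32 accumulator
// instead carries the "clamp" attribute.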

/// Return true if `type` is the E5M2 variant of an 8-bit float that is
/// supported by the `_bf8` instructions on the given `chipset`.
static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type) {
  return (chipset == kGfx942 && isa<Float8E5M2FNUZType>(type)) ||
         (hasOcpFp8(chipset) && isa<Float8E5M2Type>(type));
}

/// Return true if `type` is the E4M3FN variant of an 8-bit float that is
/// supported by the `_fp8` instructions on the given `chipset`.
static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type) {
  return (chipset == kGfx942 && isa<Float8E4M3FNUZType>(type)) ||
         (hasOcpFp8(chipset) && isa<Float8E4M3FNType>(type));
}

/// Return the `rocdl` intrinsic corresponding to an MFMA operation `mfma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
                                                  Chipset chipset) {
  uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
           b = mfma.getBlocks();
  Type sourceElem = getElementTypeOrSelf(mfma.getSourceA().getType());
  Type destElem = getElementTypeOrSelf(mfma.getDestC().getType());

  if (sourceElem.isF32() && destElem.isF32()) {
    if (mfma.getReducePrecision() && chipset >= kGfx942) {
      if (m == 32 && n == 32 && k == 4 && b == 1)
        return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
      if (m == 16 && n == 16 && k == 8 && b == 1)
        return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
    }
    if (m == 32 && n == 32 && k == 1 && b == 2)
      return ROCDL::mfma_f32_32x32x1f32::getOperationName();
    if (m == 16 && n == 16 && k == 1 && b == 4)
      return ROCDL::mfma_f32_16x16x1f32::getOperationName();
    if (m == 4 && n == 4 && k == 1 && b == 16)
      return ROCDL::mfma_f32_4x4x1f32::getOperationName();
    if (m == 32 && n == 32 && k == 2 && b == 1)
      return ROCDL::mfma_f32_32x32x2f32::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f32_16x16x4f32::getOperationName();
  }

  if (sourceElem.isF16() && destElem.isF32()) {
    if (chipset >= kGfx950) {
      if (m == 32 && n == 32 && k == 16 && b == 1)
        return ROCDL::mfma_f32_32x32x16_f16::getOperationName();
      if (m == 16 && n == 16 && k == 32 && b == 1)
        return ROCDL::mfma_f32_16x16x32_f16::getOperationName();
    }
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_f32_32x32x4f16::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_f32_16x16x4f16::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_f32_4x4x4f16::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_f32_32x32x8f16::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_f32_16x16x16f16::getOperationName();
  }

  if (sourceElem.isBF16() && destElem.isF32()) {
    if (chipset >= kGfx950) {
      if (m == 32 && n == 32 && k == 16 && b == 1)
        return ROCDL::mfma_f32_32x32x16_bf16::getOperationName();
      if (m == 16 && n == 16 && k == 32 && b == 1)
        return ROCDL::mfma_f32_16x16x32_bf16::getOperationName();
    }
    if (chipset >= kGfx90a) {
      if (m == 32 && n == 32 && k == 4 && b == 2)
        return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
      if (m == 16 && n == 16 && k == 4 && b == 4)
        return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
      if (m == 4 && n == 4 && k == 4 && b == 16)
        return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
      if (m == 32 && n == 32 && k == 8 && b == 1)
        return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
      if (m == 16 && n == 16 && k == 16 && b == 1)
        return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
    }
    if (m == 32 && n == 32 && k == 2 && b == 2)
      return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
    if (m == 16 && n == 16 && k == 2 && b == 4)
      return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
    if (m == 4 && n == 4 && k == 2 && b == 16)
      return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
    if (m == 32 && n == 32 && k == 4 && b == 1)
      return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
    if (m == 16 && n == 16 && k == 8 && b == 1)
      return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
  }

  if (sourceElem.isInteger(8) && destElem.isInteger(32)) {
    if (chipset >= kGfx950) {
      if (m == 32 && n == 32 && k == 32 && b == 1)
        return ROCDL::mfma_i32_32x32x32_i8::getOperationName();
      if (m == 16 && n == 16 && k == 64 && b == 1)
        return ROCDL::mfma_i32_16x16x64_i8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_i32_32x32x4i8::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_i32_16x16x4i8::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_i32_4x4x4i8::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_i32_32x32x8i8::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_i32_16x16x16i8::getOperationName();
    if (m == 32 && n == 32 && k == 16 && b == 1 && chipset >= kGfx942)
      return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
    if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >= kGfx942)
      return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
  }

  if (sourceElem.isF64() && destElem.isF64() && chipset >= kGfx90a) {
    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f64_16x16x4f64::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 4)
      return ROCDL::mfma_f64_4x4x4f64::getOperationName();
  }

  if (destElem.isF32() && typeIsExpectedBf8ForChipset(chipset, sourceElem)) {
    // Known to be correct because there are no scalar f8 instructions and
    // because a length mismatch will have been caught by the verifier.
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
    }
  }

  if (destElem.isF32() && typeIsExpectedFp8ForChipset(chipset, sourceElem)) {
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
    }
  }

  return std::nullopt;
}
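
// For illustration: an f32 MFMA with m = n = 32, k = 1, blocks = 2 maps to
// ROCDL::mfma_f32_32x32x1f32 above, while the reduce-precision path on
// gfx942+ requires k = 4 for the same tile and maps to
// ROCDL::mfma_f32_32x32x4_xf32.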

static std::optional<uint32_t> mfmaTypeSelectCode(Type mlirElemType) {
  return llvm::TypeSwitch<Type, std::optional<uint32_t>>(mlirElemType)
      .Case([](Float8E4M3FNType) { return 0u; })
      .Case([](Float8E5M2Type) { return 1u; })
      .Case([](Float6E2M3FNType) { return 2u; })
      .Case([](Float6E3M2FNType) { return 3u; })
      .Case([](Float4E2M1FNType) { return 4u; })
      .Default(std::nullopt);
}

/// If there is a scaled MFMA instruction for the input element types `aType`
/// and `bType`, output type `destType`, problem size M, N, K, and B (number of
/// blocks) on the given `chipset`, return a tuple consisting of the
/// OperationName of the intrinsic and the type codes that need to be passed to
/// that intrinsic. Note that this is also used to implement some un-scaled
/// MFMAs, since the compiler represents the ordinary instruction as a "scaled"
/// MFMA with a scale of 0.
static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m,
                        uint32_t n, uint32_t k, uint32_t b, Chipset chipset) {
  aType = getElementTypeOrSelf(aType);
  bType = getElementTypeOrSelf(bType);
  destType = getElementTypeOrSelf(destType);

  if (chipset < kGfx950)
    return std::nullopt;
  if (!isa<Float32Type>(destType))
    return std::nullopt;

  std::optional<uint32_t> aTypeCode = mfmaTypeSelectCode(aType);
  std::optional<uint32_t> bTypeCode = mfmaTypeSelectCode(bType);
  if (!aTypeCode || !bTypeCode)
    return std::nullopt;

  if (m == 32 && n == 32 && k == 64 && b == 1)
    return std::tuple{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(),
                      *aTypeCode, *bTypeCode};
  if (m == 16 && n == 16 && k == 128 && b == 1)
    return std::tuple{
        ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), *aTypeCode,
        *bTypeCode};

  return std::nullopt;
}
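
// For illustration: on gfx950, a-type elements of f8E4M3FN, b-type elements of
// f8E5M2, an f32 destination, and (m, n, k, b) = (32, 32, 64, 1) yield the
// rocdl.mfma.scale.f32.32x32x64.f8f6f4 operation name with the type codes
// (0, 1) assigned in mfmaTypeSelectCode above.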

static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
mfmaOpToScaledIntrinsic(MFMAOp mfma, Chipset chipset) {
  return mfmaOpToScaledIntrinsic(
      mfma.getSourceA().getType(), mfma.getSourceB().getType(),
      mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(),
      mfma.getBlocks(), chipset);
}

static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
  return mfmaOpToScaledIntrinsic(smfma.getSourceA().getType(),
                                 smfma.getSourceB().getType(),
                                 smfma.getDestC().getType(), smfma.getM(),
                                 smfma.getN(), smfma.getK(), 1u, chipset);
}

/// Returns the `rocdl` intrinsic corresponding to a WMMA operation
/// for the RDNA3/4 architectures.
static std::optional<StringRef>
wmmaOpToIntrinsicRDNA(Type elemSourceType, Type elemBSourceType,
                      Type elemDestType, uint32_t k, bool isRDNA3) {
  using fp8 = Float8E4M3FNType;
  using bf8 = Float8E5M2Type;

  // Handle k == 16 for RDNA3/4.
  if (k == 16) {
    // Common patterns for RDNA3 and RDNA4.
    if (elemSourceType.isF16() && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
    if (elemSourceType.isBF16() && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
    if (elemSourceType.isF16() && elemDestType.isF16())
      return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
    if (elemSourceType.isBF16() && elemDestType.isBF16())
      return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
    if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
      return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();

    // RDNA3-specific patterns.
    if (isRDNA3) {
      if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
        return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
      return std::nullopt;
    }

    // RDNA4-specific patterns (fp8/bf8).
    if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
        elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
    if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
        elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
    if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
        elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
    if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
        elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
      return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();

    return std::nullopt;
  }

  // Handle k == 32 for RDNA4.
  if (k == 32 && !isRDNA3) {
    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
      return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
  }

  return std::nullopt;
}

/// Returns the `rocdl` intrinsic corresponding to a WMMA operation
/// for the gfx1250 architecture.
static std::optional<StringRef> wmmaOpToIntrinsicGfx1250(Type elemSourceType,
                                                         Type elemBSourceType,
                                                         Type elemDestType,
                                                         uint32_t k) {
  using fp8 = Float8E4M3FNType;
  using bf8 = Float8E5M2Type;

  if (k == 4) {
    if (elemSourceType.isF32() && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x4_f32::getOperationName();

    return std::nullopt;
  }

  if (k == 32) {
    if (elemSourceType.isF16() && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x32_f16::getOperationName();
    if (elemSourceType.isBF16() && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x32_bf16::getOperationName();
    if (elemSourceType.isF16() && elemDestType.isF16())
      return ROCDL::wmma_f16_16x16x32_f16::getOperationName();
    if (elemSourceType.isBF16() && elemDestType.isBF16())
      return ROCDL::wmma_bf16_16x16x32_bf16::getOperationName();

    return std::nullopt;
  }

  if (k == 64) {
    if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x64_fp8_fp8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x64_fp8_fp8::getOperationName();
    }
    if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x64_fp8_bf8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x64_fp8_bf8::getOperationName();
    }
    if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x64_bf8_bf8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x64_bf8_bf8::getOperationName();
    }
    if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x64_bf8_fp8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x64_bf8_fp8::getOperationName();
    }
    if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
      return ROCDL::wmma_i32_16x16x64_iu8::getOperationName();

    return std::nullopt;
  }

  if (k == 128) {
    if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x128_fp8_fp8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x128_fp8_fp8::getOperationName();
    }
    if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x128_fp8_bf8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x128_fp8_bf8::getOperationName();
    }
    if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x128_bf8_bf8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x128_bf8_bf8::getOperationName();
    }
    if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x128_bf8_fp8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x128_bf8_fp8::getOperationName();
    }

    return std::nullopt;
  }

  return std::nullopt;
}

/// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
                                                  Chipset chipset) {
  auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
  auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
  auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
  Type elemSourceType = sourceVectorType.getElementType();
  Type elemBSourceType = sourceBVectorType.getElementType();
  Type elemDestType = destVectorType.getElementType();

  const uint32_t k = wmma.getK();
  const bool isRDNA3 = chipset.majorVersion == 11;
  const bool isRDNA4 = chipset.majorVersion == 12 && chipset.minorVersion == 0;

  // Handle RDNA3 and RDNA4.
  if (isRDNA3 || isRDNA4)
    return wmmaOpToIntrinsicRDNA(elemSourceType, elemBSourceType, elemDestType,
                                 k, isRDNA3);

  // Handle gfx1250.
  if (chipset == kGfx1250)
    return wmmaOpToIntrinsicGfx1250(elemSourceType, elemBSourceType,
                                    elemDestType, k);

  return std::nullopt;
}
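
// For illustration: an f16 x f16 -> f32 WMMA with k = 16 dispatches to
// wmmaOpToIntrinsicRDNA and maps to ROCDL::wmma_f32_16x16x16_f16 on
// gfx11/gfx12.0, while the same element types with k = 32 on gfx1250 map to
// ROCDL::wmma_f32_16x16x32_f16.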

namespace {
struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
  MFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type outType = typeConverter->convertType(op.getDestD().getType());
    Type intrinsicOutType = outType;
    if (auto outVecType = dyn_cast<VectorType>(outType))
      if (outVecType.getElementType().isBF16())
        intrinsicOutType = outVecType.clone(rewriter.getI16Type());

    if (chipset.majorVersion != 9 || chipset < kGfx908)
      return op->emitOpError("MFMA only supported on gfx908+");
    uint32_t getBlgpField = static_cast<uint32_t>(op.getBlgp());
    if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
      if (chipset < kGfx942)
        return op.emitOpError("negation unsupported on older than gfx942");
      getBlgpField |=
          op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
    }
    std::optional<StringRef> maybeIntrinsic = mfmaOpToIntrinsic(op, chipset);
    std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
        maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
    if (!maybeIntrinsic.has_value() && !maybeScaledIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching MFMA size on given chipset");

    bool isScaled =
        !maybeIntrinsic.has_value() && maybeScaledIntrinsic.has_value();
    if (isScaled &&
        (adaptor.getAbid() > 0 || getBlgpField > 0 || op.getCbsz() > 0)) {
      return op.emitOpError(
          "non-default abid, blgp, and cbsz aren't supported on MFMAs that can "
          "be scaled as those fields are used for type information");
    }

    StringRef intrinsicName =
        isScaled ? std::get<0>(*maybeScaledIntrinsic) : *maybeIntrinsic;
    // Determine if we can use bf16 in the intrinsic: newer MFMAs on gfx950+
    // allow bf16 as the input (see IntrinsicsAMDGPU.td for reference).
    bool allowBf16 = [&]() {
      if (chipset < kGfx950)
        return false;
      if (isScaled)
        return true;
      return intrinsicName.contains("16x16x32.bf16") ||
             intrinsicName.contains("32x32x16.bf16");
    }();
    OperationState loweredOp(loc, intrinsicName);
    loweredOp.addTypes(intrinsicOutType);
    loweredOp.addOperands({convertMFMAVectorOperand(
                               rewriter, loc, adaptor.getSourceA(), allowBf16),
                           convertMFMAVectorOperand(
                               rewriter, loc, adaptor.getSourceB(), allowBf16),
                           adaptor.getDestC()});
    if (isScaled) {
      Value zero = createI32Constant(rewriter, loc, 0);
      auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
      loweredOp.addOperands({createI32Constant(rewriter, loc, aTypeCode),
                             createI32Constant(rewriter, loc, bTypeCode),
                             /*scale A byte=*/zero, /*scale A=*/zero,
                             /*scale B byte=*/zero, /*scale B=*/zero});
    } else {
      loweredOp.addOperands({createI32Constant(rewriter, loc, op.getCbsz()),
                             createI32Constant(rewriter, loc, op.getAbid()),
                             createI32Constant(rewriter, loc, getBlgpField)});
    }
    Value lowered = rewriter.create(loweredOp)->getResult(0);
    if (outType != intrinsicOutType)
      lowered = LLVM::BitcastOp::create(rewriter, loc, outType, lowered);
    rewriter.replaceOp(op, lowered);
    return success();
  }
};

struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
  ScaledMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type intrinsicOutType = typeConverter->convertType(op.getDestD().getType());

    if (chipset.majorVersion != 9 || chipset < kGfx950)
      return op->emitOpError("scaled MFMA only supported on gfx950+");
    std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
        maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
    if (!maybeScaledIntrinsic.has_value())
      return op.emitOpError(
          "no intrinsic matching scaled MFMA size on given chipset");

    auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
    OperationState loweredOp(loc, intrinsicName);
    loweredOp.addTypes(intrinsicOutType);
    loweredOp.addOperands(
        {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
         convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
         adaptor.getDestC()});
    Value scalesIdxA =
        createI32Constant(rewriter, loc, adaptor.getScalesIdxA());
    Value scalesIdxB =
        createI32Constant(rewriter, loc, adaptor.getScalesIdxB());
    loweredOp.addOperands(
        {createI32Constant(rewriter, loc, aTypeCode),
         createI32Constant(rewriter, loc, bTypeCode),
         /*scales idx A=*/scalesIdxA,
         /*scales A=*/
         castMFMAScaleOperand(rewriter, loc, adaptor.getScalesA()),
         /*scales idx B=*/scalesIdxB,
         /*scales B=*/
         castMFMAScaleOperand(rewriter, loc, adaptor.getScalesB())});
    Value lowered = rewriter.create(loweredOp)->getResult(0);
    rewriter.replaceOp(op, lowered);
    return success();
  }
};

struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
  WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto outType =
        typeConverter->convertType<VectorType>(op.getDestD().getType());
    if (!outType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
      return op->emitOpError("WMMA only supported on gfx11 and gfx12");

    bool isGFX1250 = chipset >= kGfx1250;

    // The WMMA operations represent vectors of bf16s as vectors of i16s
    // (except on gfx1250), so we need to bitcast bfloats to i16 and then
    // bitcast them back.
    auto aType = cast<VectorType>(adaptor.getSourceA().getType());
    auto bType = cast<VectorType>(adaptor.getSourceB().getType());
    auto destCType = cast<VectorType>(adaptor.getDestC().getType());
    bool castAToI16 = aType.getElementType().isBF16() && !isGFX1250;
    bool castBToI16 = bType.getElementType().isBF16() && !isGFX1250;
    bool castDestCToI16 = destCType.getElementType().isBF16() && !isGFX1250;
    bool castOutToI16 = outType.getElementType().isBF16() && !isGFX1250;
    VectorType rawOutType = outType;
    if (castOutToI16)
      rawOutType = outType.clone(rewriter.getI16Type());
    Value a = adaptor.getSourceA();
    if (castAToI16)
      a = LLVM::BitcastOp::create(rewriter, loc,
                                  aType.clone(rewriter.getI16Type()), a);
    Value b = adaptor.getSourceB();
    if (castBToI16)
      b = LLVM::BitcastOp::create(rewriter, loc,
                                  bType.clone(rewriter.getI16Type()), b);
    Value destC = adaptor.getDestC();
    if (castDestCToI16)
      destC = LLVM::BitcastOp::create(
          rewriter, loc, destCType.clone(rewriter.getI16Type()), destC);

    std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);

    if (!maybeIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching WMMA on the given chipset");

    if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0)
      return op.emitOpError("subwordOffset not supported on gfx12+");

    SmallVector<Value, 4> operands;
    SmallVector<NamedAttribute, 4> attrs;
    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(), a,
                         op.getSourceA(), operands, attrs, "signA");
    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(), b,
                         op.getSourceB(), operands, attrs, "signB");
    wmmaPushOutputOperand(rewriter, loc, typeConverter, destC,
                          op.getSubwordOffset(), op.getClamp(), operands,
                          attrs);

    OperationState loweredOp(loc, *maybeIntrinsic);
    loweredOp.addTypes(rawOutType);
    loweredOp.addOperands(operands);
    loweredOp.addAttributes(attrs);
    Operation *lowered = rewriter.create(loweredOp);

    Operation *maybeCastBack = lowered;
    if (rawOutType != outType)
      maybeCastBack = LLVM::BitcastOp::create(rewriter, loc, outType,
                                              lowered->getResult(0));
    rewriter.replaceOp(op, maybeCastBack->getResults());

    return success();
  }
};
1372
1373struct TransposeLoadOpLowering
1374 : public ConvertOpToLLVMPattern<TransposeLoadOp> {
1375 TransposeLoadOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1376 : ConvertOpToLLVMPattern<TransposeLoadOp>(converter), chipset(chipset) {}
1377
1378 Chipset chipset;
1379
1380 LogicalResult
1381 matchAndRewrite(TransposeLoadOp op, TransposeLoadOpAdaptor adaptor,
1382 ConversionPatternRewriter &rewriter) const override {
1383 if (chipset != kGfx950)
1384 return op.emitOpError("Non-gfx950 chipset not supported");
1385
1386 Location loc = op.getLoc();
1387 auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
1388
1389 // Elements in subbyte memrefs are stored non-contiguously,
1390 // reject if source is sub-byte memref. Use emulated memrefs instead.
1391 size_t srcElementSize =
1392 srcMemRefType.getElementType().getIntOrFloatBitWidth();
1393 if (srcElementSize < 8)
1394 return op.emitOpError("Expect source memref to have at least 8 bits "
1395 "element size, got ")
1396 << srcElementSize;
1397
1398 auto resultType = cast<VectorType>(op.getResult().getType());
1399 Value srcPtr =
1400 getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
1401 (adaptor.getSrcIndices()));
1402
1403 size_t numElements = resultType.getNumElements();
1404 size_t elementTypeSize =
1405 resultType.getElementType().getIntOrFloatBitWidth();
1406
1407 // ROCDL transpose load intrinsics return vectors of 32-bit integers when
1408 // the element size is smaller than 16 bits.
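// E.g. a vector<16xi4> result is 64 bits total, so it lowers to
// ds_read_tr4_b64 returning vector<2xi32> ((16 * 4) / 32 == 2), which is then
// bitcast back to the converted result type.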
1409 Type rocdlResultType = VectorType::get((numElements * elementTypeSize) / 32,
1410 rewriter.getIntegerType(32));
1411 Type llvmResultType = typeConverter->convertType(resultType);
1412
1413 switch (elementTypeSize) {
1414 case 4: {
1415 assert(numElements == 16);
1416 auto rocdlOp = ROCDL::ds_read_tr4_b64::create(rewriter, loc,
1417 rocdlResultType, srcPtr);
1418 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
1419 break;
1420 }
1421 case 6: {
1422 assert(numElements == 16);
1423 auto rocdlOp = ROCDL::ds_read_tr6_b96::create(rewriter, loc,
1424 rocdlResultType, srcPtr);
1425 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
1426 break;
1427 }
1428 case 8: {
1429 assert(numElements == 8);
1430 auto rocdlOp = ROCDL::ds_read_tr8_b64::create(rewriter, loc,
1431 rocdlResultType, srcPtr);
1432 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
1433 break;
1434 }
1435 case 16: {
1436 assert(numElements == 4);
1437 rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr16_b64>(op, llvmResultType,
1438 srcPtr);
1439 break;
1440 }
1441 default:
1442 return op.emitOpError("Unsupported element size for transpose load");
1443 }
1444 return success();
1445 }
1446};
1447
1448struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
1449 GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1450 : ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}
1451
1452 Chipset chipset;
1453
1454 LogicalResult
1455 matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor,
1456 ConversionPatternRewriter &rewriter) const override {
1457 if (chipset.majorVersion < 9 || chipset.majorVersion > 10)
1458 return op.emitOpError("pre-gfx9 and post-gfx10 not supported");
1459
1460 Location loc = op.getLoc();
1461
1462 auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
1463 auto dstMemRefType = cast<MemRefType>(op.getDst().getType());
1464
1465 // TODO: instead of only transferring one element per thread, we could
1466 // augment it to transfer multiple elements per thread by issuing multiple
1467 // `global_load_lds` instructions.
1468 Type transferType = op.getTransferType();
1469 int loadWidth = [&]() -> int {
1470 if (auto transferVectorType = dyn_cast<VectorType>(transferType)) {
1471 return (transferVectorType.getNumElements() *
1472 transferVectorType.getElementTypeBitWidth()) /
1473 8;
1474 }
1475 return transferType.getIntOrFloatBitWidth() / 8;
1476 }();
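// E.g. a transfer type of vector<2xf16> yields (2 * 16) / 8 == 4 bytes, and a
// scalar f32 yields 32 / 8 == 4 bytes.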
1477
1478 // Currently only 1-, 2-, 4-, 12- and 16-byte loads are supported.
1479 if (!llvm::is_contained({1, 2, 4, 12, 16}, loadWidth))
1480 return op.emitOpError("unsupported load width for this chipset");
1481
1482 if (chipset != kGfx950 && llvm::is_contained({12, 16}, loadWidth))
1483 return op.emitOpError("Gather to LDS instructions with 12-byte and "
1484 "16-byte load widths are only supported on gfx950");
1485
1486 Value srcPtr =
1487 getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
1488 (adaptor.getSrcIndices()));
1489 Value dstPtr =
1490 getStridedElementPtr(rewriter, loc, dstMemRefType, adaptor.getDst(),
1491 (adaptor.getDstIndices()));
1492
1493 rewriter.replaceOpWithNewOp<ROCDL::LoadToLDSOp>(
1494 op, srcPtr, dstPtr, rewriter.getI32IntegerAttr(loadWidth),
1495 /*offset=*/rewriter.getI32IntegerAttr(0),
1496 /*aux=*/rewriter.getI32IntegerAttr(0), ArrayAttr{}, ArrayAttr{},
1497 ArrayAttr{});
1498
1499 return success();
1500 }
1501};
1502
1503namespace {
1504struct ExtPackedFp8OpLowering final
1505 : public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
1506 ExtPackedFp8OpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1507 : ConvertOpToLLVMPattern<amdgpu::ExtPackedFp8Op>(converter),
1508 chipset(chipset) {}
1509 Chipset chipset;
1510
1511 LogicalResult
1512 matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
1513 ConversionPatternRewriter &rewriter) const override;
1514};
1515
1516struct ScaledExtPackedMatrixOpLowering final
1517 : public ConvertOpToLLVMPattern<ScaledExtPackedMatrixOp> {
1518 ScaledExtPackedMatrixOpLowering(const LLVMTypeConverter &converter,
1519 Chipset chipset)
1520 : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedMatrixOp>(converter),
1521 chipset(chipset) {}
1522 Chipset chipset;
1523
1524 LogicalResult
1525 matchAndRewrite(ScaledExtPackedMatrixOp op,
1526 ScaledExtPackedMatrixOpAdaptor adaptor,
1527 ConversionPatternRewriter &rewriter) const override;
1528};
1529
1530struct PackedTrunc2xFp8OpLowering final
1531 : public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> {
1532 PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter,
1533 Chipset chipset)
1534 : ConvertOpToLLVMPattern<amdgpu::PackedTrunc2xFp8Op>(converter),
1535 chipset(chipset) {}
1536 Chipset chipset;
1537
1538 LogicalResult
1539 matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
1540 ConversionPatternRewriter &rewriter) const override;
1541};
1542
1543struct PackedStochRoundFp8OpLowering final
1544 : public ConvertOpToLLVMPattern<PackedStochRoundFp8Op> {
1545 PackedStochRoundFp8OpLowering(const LLVMTypeConverter &converter,
1546 Chipset chipset)
1547 : ConvertOpToLLVMPattern<amdgpu::PackedStochRoundFp8Op>(converter),
1548 chipset(chipset) {}
1549 Chipset chipset;
1550
1551 LogicalResult
1552 matchAndRewrite(PackedStochRoundFp8Op op,
1553 PackedStochRoundFp8OpAdaptor adaptor,
1554 ConversionPatternRewriter &rewriter) const override;
1555};
1556
1557struct ScaledExtPackedOpLowering final
1558 : public ConvertOpToLLVMPattern<ScaledExtPackedOp> {
1559 ScaledExtPackedOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1560 : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedOp>(converter),
1561 chipset(chipset) {}
1562 Chipset chipset;
1563
1564 LogicalResult
1565 matchAndRewrite(ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
1566 ConversionPatternRewriter &rewriter) const override;
1567};
1568
1569struct PackedScaledTruncOpLowering final
1570 : public ConvertOpToLLVMPattern<PackedScaledTruncOp> {
1571 PackedScaledTruncOpLowering(const LLVMTypeConverter &converter,
1572 Chipset chipset)
1573 : ConvertOpToLLVMPattern<amdgpu::PackedScaledTruncOp>(converter),
1574 chipset(chipset) {}
1575 Chipset chipset;
1576
1577 LogicalResult
1578 matchAndRewrite(PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
1579 ConversionPatternRewriter &rewriter) const override;
1580};
1581
1582} // end namespace
1583
1584LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
1585 ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
1586 ConversionPatternRewriter &rewriter) const {
1587 Location loc = op.getLoc();
1588 if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
1589 return rewriter.notifyMatchFailure(
1590 loc, "Fp8 conversion instructions are not available on target "
1591 "architecture and their emulation is not implemented");
1592 Type v4i8 =
1593 getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type()));
1594 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
1595 Type f32 = getTypeConverter()->convertType(op.getResult().getType());
1596
1597 Value source = adaptor.getSource();
1598 auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
1599 auto resultVecType = dyn_cast<VectorType>(op.getResult().getType());
1600 Type sourceElemType = getElementTypeOrSelf(op.getSource());
1601 // Extend to a v4i8
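// E.g. a scalar fp8 source lands in element 0 of an undef v4i8, which is then
// bitcast to i32 so the conversion intrinsic can select the byte at
// op.getIndex().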
1602 if (!sourceVecType || sourceVecType.getNumElements() < 4) {
1603 Value longVec = LLVM::UndefOp::create(rewriter, loc, v4i8);
1604 if (!sourceVecType) {
1605 longVec = LLVM::InsertElementOp::create(
1606 rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0));
1607 } else {
1608 for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
1609 Value idx = createI32Constant(rewriter, loc, i);
1610 Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
1611 longVec =
1612 LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
1613 }
1614 }
1615 source = longVec;
1616 }
1617 Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
1618 if (resultVecType) {
1619 if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
1620 rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source,
1621 op.getIndex());
1622 } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
1623 rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Fp8Op>(op, f32, i32Source,
1624 op.getIndex());
1625 }
1626 } else {
1627 if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
1628 rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
1629 op.getIndex());
1630 } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
1631 rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
1632 op.getIndex());
1633 }
1634 }
1635 return success();
1636}
1637
1638int32_t getScaleSel(int32_t blockSize, unsigned bitWidth, int32_t scaleWaveHalf,
1639 int32_t firstScaleByte) {
1640 // When lowering amdgpu.scaled_ext_packed_matrix to rocdl.cvt.scale.pk*.f*.f*
1641 // operations, the attributes blockSize, sourceType, scaleWaveHalf, and
1642 // firstScaleByte are merged into the single attribute scaleSel; this
1643 // function performs that merge. (Note: scaleWaveHalf isn't a high-level
1644 // attribute but is derived from firstScaleLane.)
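// For example, following the packing below: an fp8 input with blockSize = 32,
// firstScaleByte = 2, scaleWaveHalf = 1 gives bit0 = 0, bits2and1 = 0b100,
// bit3 = 0b1000, i.e. scaleSel = 0b1100; an fp4 input with blockSize = 16,
// firstScaleByte = 2, scaleWaveHalf = 0 gives scaleSel = 0b011.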
1645 assert(llvm::is_contained({16, 32}, blockSize));
1646 assert(llvm::is_contained(llvm::ArrayRef<unsigned>{4, 6, 8}, bitWidth));
1647
1648 const bool isFp8 = bitWidth == 8;
1649 const bool isBlock16 = blockSize == 16;
1650
1651 if (!isFp8) {
1652 int32_t bit0 = isBlock16;
1653 assert(llvm::is_contained({0, 1, 2}, firstScaleByte));
1654 int32_t bit1 = (firstScaleByte == 2) << 1;
1655 assert(llvm::is_contained({0, 1}, scaleWaveHalf));
1656 int32_t bit2 = scaleWaveHalf << 2;
1657 return bit2 | bit1 | bit0;
1658 }
1659
1660 int32_t bit0 = isBlock16;
1661 // firstScaleByte is guaranteed to fit in two bits.
1662 assert(llvm::is_contained({0, 1, 2, 3}, firstScaleByte));
1663 int32_t bits2and1 = firstScaleByte << 1;
1664 assert(llvm::is_contained({0, 1}, scaleWaveHalf));
1665 int32_t bit3 = scaleWaveHalf << 3;
1666 int32_t bits = bit3 | bits2and1 | bit0;
1667 // These are invalid cases.
1668 assert(!llvm::is_contained(
1669 {0b0011, 0b0101, 0b0111, 0b1000, 0b1001, 0b1011, 0b1111}, bits));
1670 return bits;
1671}
1672
1673static std::optional<StringRef>
1674scaledExtPacked816ToIntrinsic(Type srcElemType, Type destElemType) {
1675 using fp4 = Float4E2M1FNType;
1676 using fp8 = Float8E4M3FNType;
1677 using bf8 = Float8E5M2Type;
1678 using fp6 = Float6E2M3FNType;
1679 using bf6 = Float6E3M2FNType;
1680 if (isa<fp4>(srcElemType)) {
1681 if (destElemType.isF16())
1682 return ROCDL::CvtPkScalePk8F16Fp4Op::getOperationName();
1683 if (destElemType.isBF16())
1684 return ROCDL::CvtPkScalePk8Bf16Fp4Op::getOperationName();
1685 if (destElemType.isF32())
1686 return ROCDL::CvtPkScalePk8F32Fp4Op::getOperationName();
1687 return std::nullopt;
1688 }
1689 if (isa<fp8>(srcElemType)) {
1690 if (destElemType.isF16())
1691 return ROCDL::CvtPkScalePk8F16Fp8Op::getOperationName();
1692 if (destElemType.isBF16())
1693 return ROCDL::CvtPkScalePk8Bf16Fp8Op::getOperationName();
1694 if (destElemType.isF32())
1695 return ROCDL::CvtPkScalePk8F32Fp8Op::getOperationName();
1696 return std::nullopt;
1697 }
1698 if (isa<bf8>(srcElemType)) {
1699 if (destElemType.isF16())
1700 return ROCDL::CvtPkScalePk8F16Bf8Op::getOperationName();
1701 if (destElemType.isBF16())
1702 return ROCDL::CvtPkScalePk8Bf16Bf8Op::getOperationName();
1703 if (destElemType.isF32())
1704 return ROCDL::CvtPkScalePk8F32Bf8Op::getOperationName();
1705 return std::nullopt;
1706 }
1707 if (isa<fp6>(srcElemType)) {
1708 if (destElemType.isF16())
1709 return ROCDL::CvtPkScalePk16F16Fp6Op::getOperationName();
1710 if (destElemType.isBF16())
1711 return ROCDL::CvtPkScalePk16Bf16Fp6Op::getOperationName();
1712 if (destElemType.isF32())
1713 return ROCDL::CvtPkScalePk16F32Fp6Op::getOperationName();
1714 return std::nullopt;
1715 }
1716 if (isa<bf6>(srcElemType)) {
1717 if (destElemType.isF16())
1718 return ROCDL::CvtPkScalePk16F16Bf6Op::getOperationName();
1719 if (destElemType.isBF16())
1720 return ROCDL::CvtPkScalePk16Bf16Bf6Op::getOperationName();
1721 if (destElemType.isF32())
1722 return ROCDL::CvtPkScalePk16F32Bf6Op::getOperationName();
1723 return std::nullopt;
1724 }
1725 llvm_unreachable("invalid combination of element types for packed conversion "
1726 "instructions");
1727}
1728
1729LogicalResult ScaledExtPackedMatrixOpLowering::matchAndRewrite(
1730 ScaledExtPackedMatrixOp op, ScaledExtPackedMatrixOpAdaptor adaptor,
1731 ConversionPatternRewriter &rewriter) const {
1732 using fp4 = Float4E2M1FNType;
1733 using fp8 = Float8E4M3FNType;
1734 using bf8 = Float8E5M2Type;
1735 using fp6 = Float6E2M3FNType;
1736 using bf6 = Float6E3M2FNType;
1737 Location loc = op.getLoc();
1738 if (chipset != kGfx1250) {
1739 return rewriter.notifyMatchFailure(
1740 loc,
1741 "Scaled fp packed conversion instructions are not available on target "
1742 "architecture and their emulation is not implemented");
1743 }
1744 // Convert user-facing firstScaleLane (0 or 16) to the half of the wave that
1745 // is being selected.
1746 int32_t scaleWaveHalf = op.getFirstScaleLane() / 16;
1747 int32_t firstScaleByte = op.getFirstScaleByte();
1748 int32_t blockSize = op.getBlockSize();
1749 auto sourceType = cast<VectorType>(op.getSource().getType());
1750 auto srcElemType = cast<FloatType>(sourceType.getElementType());
1751 unsigned bitWidth = srcElemType.getWidth();
1752
1753 auto targetType = cast<VectorType>(op.getResult().getType());
1754 auto destElemType = cast<FloatType>(targetType.getElementType());
1755
1756 IntegerType i32 = rewriter.getI32Type();
1757 Value source = adaptor.getSource();
1758 Type llvmResultType = typeConverter->convertType(op.getResult().getType());
1759 Type packedType = nullptr;
1760 if (isa<fp4>(srcElemType)) {
1761 packedType = i32;
1762 packedType = getTypeConverter()->convertType(packedType);
1763 } else if (isa<fp8, bf8>(srcElemType)) {
1764 packedType = VectorType::get(2, i32);
1765 packedType = getTypeConverter()->convertType(packedType);
1766 } else if (isa<fp6, bf6>(srcElemType)) {
1767 packedType = VectorType::get(3, i32);
1768 packedType = getTypeConverter()->convertType(packedType);
1769 } else {
1770 llvm_unreachable("invalid element type for packed scaled ext");
1771 }
1772
1773 if (!packedType || !llvmResultType) {
1774 return rewriter.notifyMatchFailure(op, "type conversion failed");
1775 }
1776
1777 std::optional<StringRef> maybeIntrinsic =
1778 scaledExtPacked816ToIntrinsic(srcElemType, destElemType);
1779 if (!maybeIntrinsic.has_value())
1780 return op.emitOpError(
1781 "no intrinsic matching packed scaled conversion on the given chipset");
1782
1783 int32_t scaleSel =
1784 getScaleSel(blockSize, bitWidth, scaleWaveHalf, firstScaleByte);
1785 Value castedScale =
1786 LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getScale());
1787 Value castedSource =
1788 LLVM::BitcastOp::create(rewriter, loc, packedType, source);
1789
1790 OperationState loweredOp(loc, *maybeIntrinsic);
1791 loweredOp.addTypes({llvmResultType});
1792 loweredOp.addOperands({castedSource, castedScale});
1793
1794 SmallVector<NamedAttribute, 1> attrs;
1795 attrs.push_back(
1796 NamedAttribute("scaleSel", rewriter.getI32IntegerAttr(scaleSel)));
1797
1798 loweredOp.addAttributes(attrs);
1799 Operation *lowered = rewriter.create(loweredOp);
1800 rewriter.replaceOp(op, lowered);
1801
1802 return success();
1803}
1804
1805LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
1806 ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
1807 ConversionPatternRewriter &rewriter) const {
1808 Location loc = op.getLoc();
1809 if (chipset != kGfx950)
1810 return rewriter.notifyMatchFailure(
1811 loc, "Scaled fp conversion instructions are not available on target "
1812 "architecture and their emulation is not implemented");
1813 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
1814
1815 Value source = adaptor.getSource();
1816 Value scale = adaptor.getScale();
1817
1818 VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
1819 Type sourceElemType = sourceVecType.getElementType();
1820 VectorType destVecType = cast<VectorType>(op.getResult().getType());
1821 Type destElemType = destVecType.getElementType();
1822
1823 VectorType packedVecType;
1824 if (isa<Float8E5M2Type, Float8E4M3FNType>(sourceElemType)) {
1825 VectorType v4i8 = VectorType::get(4, rewriter.getI8Type());
1826 packedVecType = cast<VectorType>(getTypeConverter()->convertType(v4i8));
1827 } else if (isa<Float4E2M1FNType>(sourceElemType)) {
1828 VectorType v8i4 = VectorType::get(8, rewriter.getI4Type());
1829 packedVecType = cast<VectorType>(getTypeConverter()->convertType(v8i4));
1830 } else {
1831 llvm_unreachable("invalid element type for scaled ext");
1832 }
1833
1834 // Extend to a packedVectorType
1835 if (sourceVecType.getNumElements() < packedVecType.getNumElements()) {
1836 Value longVec = LLVM::ZeroOp::create(rewriter, loc, packedVecType);
// sourceVecType comes from a cast<> above and is therefore never null, so
// simply copy the source elements into the zero-initialized packed vector.
1841 for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
1842 Value idx = createI32Constant(rewriter, loc, i);
1843 Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
1844 longVec =
1845 LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
1846 }
1848 source = longVec;
1849 }
1850 Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
1851
1852 if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF32())
1853 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Bf8Op>(
1854 op, destVecType, i32Source, scale, op.getIndex());
1855 else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF16())
1856 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Bf8Op>(
1857 op, destVecType, i32Source, scale, op.getIndex());
1858 else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isBF16())
1859 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Bf8Op>(
1860 op, destVecType, i32Source, scale, op.getIndex());
1861 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF32())
1862 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp8Op>(
1863 op, destVecType, i32Source, scale, op.getIndex());
1864 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF16())
1865 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp8Op>(
1866 op, destVecType, i32Source, scale, op.getIndex());
1867 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isBF16())
1868 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp8Op>(
1869 op, destVecType, i32Source, scale, op.getIndex());
1870 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF32())
1871 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp4Op>(
1872 op, destVecType, i32Source, scale, op.getIndex());
1873 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF16())
1874 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp4Op>(
1875 op, destVecType, i32Source, scale, op.getIndex());
1876 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isBF16())
1877 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp4Op>(
1878 op, destVecType, i32Source, scale, op.getIndex());
1879 else
1880 return failure();
1881
1882 return success();
1883}
1884
1885LogicalResult PackedScaledTruncOpLowering::matchAndRewrite(
1886 PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
1887 ConversionPatternRewriter &rewriter) const {
1888 Location loc = op.getLoc();
1889 if (chipset != kGfx950)
1890 return rewriter.notifyMatchFailure(
1891 loc, "Scaled fp conversion instructions are not available on target "
1892 "architecture and their emulation is not implemented");
1893 Type v2i16 = getTypeConverter()->convertType(
1894 VectorType::get(2, rewriter.getI16Type()));
1895 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
1896
1897 Type resultType = op.getResult().getType();
1898 Type resultElemType = getElementTypeOrSelf(resultType);
1899 VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
1900 Type sourceElemType = sourceVecType.getElementType();
1901
1902 Type intResultType = isa<Float4E2M1FNType>(resultElemType) ? i32 : v2i16;
1903
1904 Value source = adaptor.getSource();
1905 Value scale = adaptor.getScale();
1906 Value existing = adaptor.getExisting();
1907 if (existing)
1908 existing = LLVM::BitcastOp::create(rewriter, loc, intResultType, existing);
1909 else
1910 existing = LLVM::ZeroOp::create(rewriter, loc, intResultType);
1911
1912 if (sourceVecType.getNumElements() < 2) {
1913 Value c0 = createI32Constant(rewriter, loc, 0);
1914 Value elem0 = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
1915 VectorType v2 = VectorType::get(2, sourceElemType);
1916 source = LLVM::ZeroOp::create(rewriter, loc, v2);
1917 source = LLVM::InsertElementOp::create(rewriter, loc, source, elem0, c0);
1918 }
1919
1920 Value sourceA, sourceB;
1921 if (sourceElemType.isF32()) {
1922 Value c0 = createI32Constant(rewriter, loc, 0);
1923 Value c1 = createI32Constant(rewriter, loc, 1);
1924 sourceA = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
1925 sourceB = LLVM::ExtractElementOp::create(rewriter, loc, source, c1);
1926 }
1927
1928 Value result;
1929 if (sourceElemType.isF32() && isa<Float8E5M2Type>(resultElemType))
1930 result = ROCDL::CvtScaleF32PkBf8F32Op::create(rewriter, loc, intResultType,
1931 existing, sourceA, sourceB,
1932 scale, op.getIndex());
1933 else if (sourceElemType.isF16() && isa<Float8E5M2Type>(resultElemType))
1934 result = ROCDL::CvtScaleF32PkBf8F16Op::create(
1935 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
1936 else if (sourceElemType.isBF16() && isa<Float8E5M2Type>(resultElemType))
1937 result = ROCDL::CvtScaleF32PkBf8Bf16Op::create(
1938 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
1939 else if (sourceElemType.isF32() && isa<Float8E4M3FNType>(resultElemType))
1940 result = ROCDL::CvtScaleF32PkFp8F32Op::create(rewriter, loc, intResultType,
1941 existing, sourceA, sourceB,
1942 scale, op.getIndex());
1943 else if (sourceElemType.isF16() && isa<Float8E4M3FNType>(resultElemType))
1944 result = ROCDL::CvtScaleF32PkFp8F16Op::create(
1945 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
1946 else if (sourceElemType.isBF16() && isa<Float8E4M3FNType>(resultElemType))
1947 result = ROCDL::CvtScaleF32PkFp8Bf16Op::create(
1948 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
1949 else if (sourceElemType.isF32() && isa<Float4E2M1FNType>(resultElemType))
1950 result = ROCDL::CvtScaleF32PkFp4F32Op::create(rewriter, loc, intResultType,
1951 existing, sourceA, sourceB,
1952 scale, op.getIndex());
1953 else if (sourceElemType.isF16() && isa<Float4E2M1FNType>(resultElemType))
1954 result = ROCDL::CvtScaleF32PkFp4F16Op::create(
1955 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
1956 else if (sourceElemType.isBF16() && isa<Float4E2M1FNType>(resultElemType))
1957 result = ROCDL::CvtScaleF32PkFp4Bf16Op::create(
1958 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
1959 else
1960 return failure();
1961
1962 result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
1963 op, getTypeConverter()->convertType(resultType), result);
1964 return success();
1965}
1966
1967LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
1968 PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
1969 ConversionPatternRewriter &rewriter) const {
1970 Location loc = op.getLoc();
1971 if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
1972 return rewriter.notifyMatchFailure(
1973 loc, "Fp8 conversion instructions are not available on target "
1974 "architecture and their emulation is not implemented");
1975 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
1976
1977 Type resultType = op.getResult().getType();
1978 Type resultElemType = getElementTypeOrSelf(resultType);
1979
1980 Value sourceA = adaptor.getSourceA();
1981 Value sourceB = adaptor.getSourceB();
1982 if (!sourceB)
1983 sourceB = LLVM::UndefOp::create(rewriter, loc, sourceA.getType());
1984 Value existing = adaptor.getExisting();
1985 if (existing)
1986 existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
1987 else
1988 existing = LLVM::UndefOp::create(rewriter, loc, i32);
1989
1990 Value result;
1991 if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
1992 result = ROCDL::CvtPkBf8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
1993 existing, op.getWordIndex());
1994 else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
1995 result = ROCDL::CvtPkFp8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
1996 existing, op.getWordIndex());
1997
1998 result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
1999 op, getTypeConverter()->convertType(resultType), result);
2000 return success();
2001}
2002
2003LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
2004 PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
2005 ConversionPatternRewriter &rewriter) const {
2006 Location loc = op.getLoc();
2007 if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
2008 return rewriter.notifyMatchFailure(
2009 loc, "Fp8 conversion instructions are not available on target "
2010 "architecture and their emulation is not implemented");
2011 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2012
2013 Type resultType = op.getResult().getType();
2014 Type resultElemType = getElementTypeOrSelf(resultType);
2015
2016 Value source = adaptor.getSource();
2017 Value stoch = adaptor.getStochiasticParam();
2018 Value existing = adaptor.getExisting();
2019 if (existing)
2020 existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
2021 else
2022 existing = LLVM::UndefOp::create(rewriter, loc, i32);
2023
2024 Value result;
2025 if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
2026 result = ROCDL::CvtSrBf8F32Op::create(rewriter, loc, i32, source, stoch,
2027 existing, op.getStoreIndex());
2028 else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
2029 result = ROCDL::CvtSrFp8F32Op::create(rewriter, loc, i32, source, stoch,
2030 existing, op.getStoreIndex());
2031
2032 result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
2033 op, getTypeConverter()->convertType(resultType), result);
2034 return success();
2035}
2036
2037// AMDGPUDPPLowering converts the amdgpu.dpp operation into the
2038// corresponding ROCDL instructions.
2039struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
2040 AMDGPUDPPLowering(const LLVMTypeConverter &converter, Chipset chipset)
2041 : ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
2042 Chipset chipset;
2043
2044 LogicalResult
2045 matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
2046 ConversionPatternRewriter &rewriter) const override {
2047
2048 // Convert the source operand to the corresponding LLVM type
2049 Location loc = DppOp.getLoc();
2050 Value src = adaptor.getSrc();
2051 Value old = adaptor.getOld();
2052 Type srcType = src.getType();
2053 Type oldType = old.getType();
2054 Type llvmType = nullptr;
2055 if (srcType.getIntOrFloatBitWidth() < 32) {
2056 llvmType = rewriter.getI32Type();
2057 } else if (isa<FloatType>(srcType)) {
2058 llvmType = (srcType.getIntOrFloatBitWidth() == 32)
2059 ? rewriter.getF32Type()
2060 : rewriter.getF64Type();
2061 } else if (isa<IntegerType>(srcType)) {
2062 llvmType = (srcType.getIntOrFloatBitWidth() == 32)
2063 ? rewriter.getI32Type()
2064 : rewriter.getI64Type();
2065 }
2066 auto llvmSrcIntType = typeConverter->convertType(
2067 rewriter.getIntegerType(srcType.getIntOrFloatBitWidth()));
2068
2069 // If the source type is narrower than 32 bits, bitcast it to i32.
2070 auto convertOperand = [&](Value operand, Type operandType) {
2071 if (operandType.getIntOrFloatBitWidth() <= 16) {
2072 if (llvm::isa<FloatType>(operandType)) {
2073 operand =
2074 LLVM::BitcastOp::create(rewriter, loc, llvmSrcIntType, operand);
2075 }
2076 auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
2077 32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
2078 Value undefVec = LLVM::UndefOp::create(rewriter, loc, llvmVecType);
2079 operand =
2080 LLVM::InsertElementOp::create(rewriter, loc, undefVec, operand,
2081 createI32Constant(rewriter, loc, 0));
2082 operand = LLVM::BitcastOp::create(rewriter, loc, llvmType, operand);
2083 }
2084 return operand;
2085 };
2086
2087 src = convertOperand(src, srcType);
2088 old = convertOperand(old, oldType);
2089
2090 // This is taken from the following file llvm/lib/Target/AMDGPU/SIDefines.h
2091 enum DppCtrl : unsigned {
2092 ROW_SHL0 = 0x100,
2093 ROW_SHR0 = 0x110,
2094 ROW_ROR0 = 0x120,
2095 WAVE_SHL1 = 0x130,
2096 WAVE_ROL1 = 0x134,
2097 WAVE_SHR1 = 0x138,
2098 WAVE_ROR1 = 0x13C,
2099 ROW_MIRROR = 0x140,
2100 ROW_HALF_MIRROR = 0x141,
2101 BCAST15 = 0x142,
2102 BCAST31 = 0x143,
2103 };
2104
2105 auto kind = DppOp.getKind();
2106 auto permArgument = DppOp.getPermArgument();
2107 uint32_t DppCtrl = 0;
2108
2109 switch (kind) {
2110
2111 case DPPPerm::quad_perm:
2112 if (auto quadPermAttr = cast<ArrayAttr>(*permArgument)) {
2113 int32_t i = 0;
2114 for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
2115 uint32_t num = elem.getInt();
2116 DppCtrl |= num << (i * 2);
2117 i++;
2118 }
2119 }
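// E.g. quad_perm = [3, 2, 1, 0] packs to 0x1B (0b00011011): the two-bit
// selector for lane i of each quad occupies bits [2*i+1 : 2*i].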
2120 break;
2121 case DPPPerm::row_shl:
2122 if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
2123 DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
2124 }
2125 break;
2126 case DPPPerm::row_shr:
2127 if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
2128 DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
2129 }
2130 break;
2131 case DPPPerm::row_ror:
2132 if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
2133 DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
2134 }
2135 break;
2136 case DPPPerm::wave_shl:
2137 DppCtrl = DppCtrl::WAVE_SHL1;
2138 break;
2139 case DPPPerm::wave_shr:
2140 DppCtrl = DppCtrl::WAVE_SHR1;
2141 break;
2142 case DPPPerm::wave_rol:
2143 DppCtrl = DppCtrl::WAVE_ROL1;
2144 break;
2145 case DPPPerm::wave_ror:
2146 DppCtrl = DppCtrl::WAVE_ROR1;
2147 break;
2148 case DPPPerm::row_mirror:
2149 DppCtrl = DppCtrl::ROW_MIRROR;
2150 break;
2151 case DPPPerm::row_half_mirror:
2152 DppCtrl = DppCtrl::ROW_HALF_MIRROR;
2153 break;
2154 case DPPPerm::row_bcast_15:
2155 DppCtrl = DppCtrl::BCAST15;
2156 break;
2157 case DPPPerm::row_bcast_31:
2158 DppCtrl = DppCtrl::BCAST31;
2159 break;
2160 }
2161
2162 // Fetch the row_mask, bank_mask, and bound_ctrl attributes.
2164 auto rowMask = DppOp->getAttrOfType<IntegerAttr>("row_mask").getInt();
2165 auto bankMask = DppOp->getAttrOfType<IntegerAttr>("bank_mask").getInt();
2166 bool boundCtrl = DppOp->getAttrOfType<BoolAttr>("bound_ctrl").getValue();
2167
2168 // Create a ROCDL::DPPUpdateOp instruction with the appropriate attributes.
2169 auto dppMovOp =
2170 ROCDL::DPPUpdateOp::create(rewriter, loc, llvmType, old, src, DppCtrl,
2171 rowMask, bankMask, boundCtrl);
2172
2173 Value result = dppMovOp.getRes();
2174 if (srcType.getIntOrFloatBitWidth() < 32) {
2175 result = LLVM::TruncOp::create(rewriter, loc, llvmSrcIntType, result);
2176 if (!llvm::isa<IntegerType>(srcType)) {
2177 result = LLVM::BitcastOp::create(rewriter, loc, srcType, result);
2178 }
2179 }
2180
2181 // Replace the amdgpu.dpp operation with the new ROCDL::DPPUpdateOp.
2183 rewriter.replaceOp(DppOp, ValueRange(result));
2184 return success();
2185 }
2186};
2187
2188struct AMDGPUSwizzleBitModeLowering
2189 : public ConvertOpToLLVMPattern<SwizzleBitModeOp> {
2190 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
2191
2192 LogicalResult
2193 matchAndRewrite(SwizzleBitModeOp op, OpAdaptor adaptor,
2194 ConversionPatternRewriter &rewriter) const override {
2195 Location loc = op.getLoc();
2196 Type i32 = rewriter.getI32Type();
2197 Value src = adaptor.getSrc();
2198 SmallVector<Value> decomposed =
2199 LLVM::decomposeValue(rewriter, loc, src, i32);
2200 unsigned andMask = op.getAndMask();
2201 unsigned orMask = op.getOrMask();
2202 unsigned xorMask = op.getXorMask();
2203
2204 // Bit 15 is 0 for the BitMode swizzle.
2205 // https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/
2206 unsigned mask = andMask | (orMask << 5) | (xorMask << 10);
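// E.g. andMask = 0x1F, orMask = 0, xorMask = 1 gives mask = 0x41F; per the
// reference above, each lane then reads lane ((id & and) | or) ^ xor, i.e.
// adjacent lanes are swapped.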
2207 Value maskValue = createI32Constant(rewriter, loc, mask);
2208 SmallVector<Value> swizzled;
2209 for (Value v : decomposed) {
2210 Value res =
2211 ROCDL::DsSwizzleOp::create(rewriter, loc, v.getType(), v, maskValue);
2212 swizzled.emplace_back(res);
2213 }
2214
2215 Value result = LLVM::composeValue(rewriter, loc, swizzled, src.getType());
2216 rewriter.replaceOp(op, result);
2217 return success();
2218 }
2219};
2220
2221struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneSwapOp> {
2223
2224 AMDGPUPermlaneLowering(const LLVMTypeConverter &converter, Chipset chipset)
2225 : ConvertOpToLLVMPattern<PermlaneSwapOp>(converter), chipset(chipset) {}
2226 Chipset chipset;
2227
2228 LogicalResult
2229 matchAndRewrite(PermlaneSwapOp op, OpAdaptor adaptor,
2230 ConversionPatternRewriter &rewriter) const override {
2231 if (chipset < kGfx950)
2232 return op->emitOpError("permlane_swap is only supported on gfx950+");
2233
2234 Location loc = op.getLoc();
2235 Type i32 = rewriter.getI32Type();
2236 Value src = adaptor.getSrc();
2237 unsigned rowLength = op.getRowLength();
2238 bool fi = op.getFetchInactive();
2239 bool boundctrl = op.getBoundCtrl();
2240
2241 SmallVector<Value> decomposed =
2242 LLVM::decomposeValue(rewriter, loc, src, i32);
2243
2244 SmallVector<Value> permuted;
2245 for (Value v : decomposed) {
2246 Value res;
2247 Type i32pair = LLVM::LLVMStructType::getLiteral(
2248 rewriter.getContext(), {v.getType(), v.getType()});
2249
2250 if (rowLength == 16)
2251 res = ROCDL::Permlane16SwapOp::create(rewriter, loc, i32pair, v, v, fi,
2252 boundctrl);
2253 else if (rowLength == 32)
2254 res = ROCDL::Permlane32SwapOp::create(rewriter, loc, i32pair, v, v, fi,
2255 boundctrl);
2256 else
2257 llvm_unreachable("unsupported row length");
2258
2259 Value vdst0 = LLVM::ExtractValueOp::create(rewriter, loc, res, {0});
2260 Value vdst1 = LLVM::ExtractValueOp::create(rewriter, loc, res, {1});
2261
2262 Value isEqual = LLVM::ICmpOp::create(rewriter, loc,
2263 LLVM::ICmpPredicate::eq, vdst0, v);
2264
2265 // Per `permlane(16|32)` semantics: if the first extracted element equals
2266 // 'v', the result is the second element; otherwise it is the first.
2267 Value vdstNew =
2268 LLVM::SelectOp::create(rewriter, loc, isEqual, vdst1, vdst0);
2269 permuted.emplace_back(vdstNew);
2270 }
2271
2272 Value result = LLVM::composeValue(rewriter, loc, permuted, src.getType());
2273 rewriter.replaceOp(op, result);
2274 return success();
2275 }
2276};
2277
2278struct AMDGPUMakeDmaBaseLowering
2279 : public ConvertOpToLLVMPattern<MakeDmaBaseOp> {
2281
2282 AMDGPUMakeDmaBaseLowering(const LLVMTypeConverter &converter, Chipset chipset)
2283 : ConvertOpToLLVMPattern<MakeDmaBaseOp>(converter), chipset(chipset) {}
2284 Chipset chipset;
2285
2286 LogicalResult
2287 matchAndRewrite(MakeDmaBaseOp op, OpAdaptor adaptor,
2288 ConversionPatternRewriter &rewriter) const override {
2289 if (chipset < kGfx1250)
2290 return op->emitOpError("make_dma_base is only supported on gfx1250");
2291
2292 Location loc = op.getLoc();
2293
2294 ValueRange ldsIndices = adaptor.getLdsIndices();
2295 Value lds = adaptor.getLds();
2296 auto ldsMemRefType = cast<MemRefType>(op.getLds().getType());
2297
2298 Value ldsPtr =
2299 getStridedElementPtr(rewriter, loc, ldsMemRefType, lds, ldsIndices);
2300
2301 ValueRange globalIndices = adaptor.getGlobalIndices();
2302 Value global = adaptor.getGlobal();
2303 auto globalMemRefType = cast<MemRefType>(op.getGlobal().getType());
2304
2305 Value globalPtr = getStridedElementPtr(rewriter, loc, globalMemRefType,
2306 global, globalIndices);
2307
2308 Type i32 = rewriter.getI32Type();
2309 Type i64 = rewriter.getI64Type();
2310
2311 Value castForLdsAddr = LLVM::PtrToIntOp::create(rewriter, loc, i32, ldsPtr);
2312 Value castForGlobalAddr =
2313 LLVM::PtrToIntOp::create(rewriter, loc, i64, globalPtr);
2314
2315 Value lowHalf =
2316 LLVM::TruncOp::create(rewriter, loc, i32, castForGlobalAddr);
2317
2318 Value shift = LLVM::LShrOp::create(rewriter, loc, castForGlobalAddr,
2319 createI64Constant(rewriter, loc, 32));
2320
2321 Value highHalf = LLVM::TruncOp::create(rewriter, loc, i32, shift);
2322
2323 Value mask = createI32Constant(rewriter, loc, (1ull << 25) - 1);
2324 Value validHighHalf = LLVM::AndOp::create(rewriter, loc, highHalf, mask);
2325
2326 Value typeField = createI32Constant(rewriter, loc, 2 << 30);
2327 Value highHalfPlusType =
2328 LLVM::OrOp::create(rewriter, loc, validHighHalf, typeField);
2329
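// Assemble the v4i32 base descriptor: element 0 holds the constant 1,
// element 1 the 32-bit LDS address, element 2 global-address bits 31:0, and
// element 3 global-address bits 56:32 in its low 25 bits plus the type field
// (2 << 30).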
2330 Value c0 = createI32Constant(rewriter, loc, 0);
2331 Value c1 = createI32Constant(rewriter, loc, 1);
2332 Value c2 = createI32Constant(rewriter, loc, 2);
2333 Value c3 = createI32Constant(rewriter, loc, 3);
2334
2335 Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
2336 assert(v4i32 && "expected type conversion to succeed");
2337 Value result = LLVM::PoisonOp::create(rewriter, loc, v4i32);
2338 result = LLVM::InsertElementOp::create(rewriter, loc, result, c1, c0);
2339 result = LLVM::InsertElementOp::create(rewriter, loc, result,
2340 castForLdsAddr, c1);
2341 result = LLVM::InsertElementOp::create(rewriter, loc, result, lowHalf, c2);
2342 result = LLVM::InsertElementOp::create(rewriter, loc, result,
2343 highHalfPlusType, c3);
2344
2345 rewriter.replaceOp(op, result);
2346 return success();
2347 }
2348};
2349
2350struct AMDGPUMakeDmaDescriptorLowering
2351 : public ConvertOpToLLVMPattern<MakeDmaDescriptorOp> {
2353
2354 AMDGPUMakeDmaDescriptorLowering(const LLVMTypeConverter &converter,
2355 Chipset chipset)
2356 : ConvertOpToLLVMPattern<MakeDmaDescriptorOp>(converter),
2357 chipset(chipset) {}
2358 Chipset chipset;
2359
2360 Value getDGroup0(OpAdaptor adaptor) const { return adaptor.getBase(); }
2361
2362 Value setValueAtOffset(ConversionPatternRewriter &rewriter, Location loc,
2363 Value accumulator, Value value, int64_t shift) const {
2364 shift = shift % 32;
2366 if (shift != 0) {
2367 Value shiftAmount = createI32Constant(rewriter, loc, shift);
2368 value = LLVM::ShlOp::create(rewriter, loc, value, shiftAmount);
2369 }
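// E.g. a field at absolute bit offset 48 lands in bits [31:16] of its 32-bit
// SGPR: 48 % 32 == 16, so the value is shifted left by 16 before being OR'ed
// into the accumulator.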
2370
2371 if (matchPattern(accumulator, mlir::m_Zero()))
2372 return value;
2373
2374 return LLVM::OrOp::create(rewriter, loc, accumulator, value);
2375 }
2376
2377 Value setWorkgroupMask(MakeDmaDescriptorOp op, OpAdaptor adaptor,
2378 ConversionPatternRewriter &rewriter, Location loc,
2379 Value sgpr0) const {
2380 Value mask = op.getWorkgroupMask();
2381 if (!mask)
2382 return sgpr0;
2383
2384 Type i32 = rewriter.getI32Type();
2385 Value extendedMask = LLVM::ZExtOp::create(rewriter, loc, i32, mask);
2386 return setValueAtOffset(rewriter, loc, sgpr0, extendedMask, 0);
2387 }
2388
2389 Value setDataSize(MakeDmaDescriptorOp op, OpAdaptor adaptor,
2390 ConversionPatternRewriter &rewriter, Location loc,
2391 Value sgpr0, ArrayRef<Value> consts) const {
2392 // Compute data_size.
2393 unsigned elementTypeWidthInBits = op.getElementTypeWidth();
2394 assert(
2395 llvm::is_contained<unsigned>({8, 16, 32, 64}, elementTypeWidthInBits) &&
2396 "expected type width to be 8, 16, 32, or 64.");
2397 int64_t dataSize = llvm::Log2_32(elementTypeWidthInBits / 8);
2398 Value size = createI32Constant(rewriter, loc, dataSize);
2399 return setValueAtOffset(rewriter, loc, sgpr0, size, 16);
2400 }
2401
2402 Value setAtomicBarrier(MakeDmaDescriptorOp op, OpAdaptor adaptor,
2403 ConversionPatternRewriter &rewriter, Location loc,
2404 Value sgpr0, ArrayRef<Value> consts) const {
2405 bool atomicBarrierEnable = adaptor.getAtomicBarrierAddress() != nullptr;
2406 if (!atomicBarrierEnable)
2407 return sgpr0;
2408
2409 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 18);
2410 }
2411
2412 Value setIterateEnable(MakeDmaDescriptorOp op, OpAdaptor adaptor,
2413 ConversionPatternRewriter &rewriter, Location loc,
2414 Value sgpr0, ArrayRef<Value> consts) const {
2415 bool iterateEnable = adaptor.getGlobalIncrement() != nullptr;
2416 if (!iterateEnable)
2417 return sgpr0;
2418
2419 // TODO: In future PR, add other required fields for iteration.
2420 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 19);
2421 }
2422
2423 Value setPadEnable(MakeDmaDescriptorOp op, OpAdaptor adaptor,
2424 ConversionPatternRewriter &rewriter, Location loc,
2425 Value sgpr0, ArrayRef<Value> consts) const {
2426 bool padEnable = op.getPadAmount() != nullptr;
2427 if (!padEnable)
2428 return sgpr0;
2429
2430 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 20);
2431 }
2432
2433 Value setEarlyTimeout(MakeDmaDescriptorOp op, OpAdaptor adaptor,
2434 ConversionPatternRewriter &rewriter, Location loc,
2435 Value sgpr0, ArrayRef<Value> consts) const {
2436 if (!op.getWorkgroupMask())
2437 return sgpr0;
2438
2439 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 21);
2440 }
2441
2442 Value setPadInterval(MakeDmaDescriptorOp op, OpAdaptor adaptor,
2443 ConversionPatternRewriter &rewriter, Location loc,
2444 Value sgpr0, ArrayRef<Value> consts) const {
2445 bool padEnable = op.getPadAmount() != nullptr;
2446 if (!padEnable)
2447 return sgpr0;
2448
2449 IntegerType i32 = rewriter.getI32Type();
2450 Value padInterval = adaptor.getPadInterval();
2451 // Pre-condition: padInterval must be a power of two between 2 and 256.
2452 padInterval = LLVM::CountTrailingZerosOp::create(rewriter, loc, i32,
2453 padInterval, false);
2454 padInterval = LLVM::SubOp::create(rewriter, loc, padInterval, consts[1]);
2455 // Post-condition: padInterval is a value between 0 and 7.
2456 return setValueAtOffset(rewriter, loc, sgpr0, padInterval, 22);
2457 }
2458
2459 Value setPadAmount(MakeDmaDescriptorOp op, OpAdaptor adaptor,
2460 ConversionPatternRewriter &rewriter, Location loc,
2461 Value sgpr0, ArrayRef<Value> consts) const {
2462 bool padEnable = op.getPadAmount() != nullptr;
2463 if (!padEnable)
2464 return sgpr0;
2465
2466 Value padAmount = adaptor.getPadAmount();
2467 // Pre-condition: padAmount is a value between 1 and 128.
2468 padAmount = LLVM::SubOp::create(rewriter, loc, padAmount, consts[1]);
2469 // Post-condition: padAmount is a value between 0 and 127.
2470 return setValueAtOffset(rewriter, loc, sgpr0, padAmount, 25);
2471 }
2472
2473 Value setAtomicBarrierAddress(MakeDmaDescriptorOp op, OpAdaptor adaptor,
2474 ConversionPatternRewriter &rewriter,
2475 Location loc, Value sgpr1,
2476 ArrayRef<Value> consts) const {
2477 bool atomicBarrierEnable = adaptor.getAtomicBarrierAddress() != nullptr;
2478 if (!atomicBarrierEnable)
2479 return sgpr1;
2480
2481 Value atomicBarrierAddress = adaptor.getAtomicBarrierAddress();
2482 auto barrierAddressTy =
2483 cast<MemRefType>(op.getAtomicBarrierAddress().getType());
2484 ValueRange atomicBarrierIndices = adaptor.getAtomicBarrierIndices();
2485 atomicBarrierAddress =
2486 getStridedElementPtr(rewriter, loc, barrierAddressTy,
2487 atomicBarrierAddress, atomicBarrierIndices);
2488 IntegerType i32 = rewriter.getI32Type();
2489 // Pre-condition: atomicBarrierAddress is aligned to 8 bytes, which implies
2490 // that the 3 LSBs are zero.
2491 atomicBarrierAddress =
2492 LLVM::PtrToIntOp::create(rewriter, loc, i32, atomicBarrierAddress);
2493 atomicBarrierAddress =
2494 LLVM::LShrOp::create(rewriter, loc, atomicBarrierAddress, consts[3]);
2495 Value mask = createI32Constant(rewriter, loc, 0xFFFF);
2496 atomicBarrierAddress =
2497 LLVM::AndOp::create(rewriter, loc, atomicBarrierAddress, mask);
2498 return setValueAtOffset(rewriter, loc, sgpr1, atomicBarrierAddress, 32);
2499 }
2500
2501 std::pair<Value, Value> setTensorDim0(MakeDmaDescriptorOp op,
2502 OpAdaptor adaptor,
2503 ConversionPatternRewriter &rewriter,
2504 Location loc, Value sgpr1, Value sgpr2,
2505 ArrayRef<Value> consts) const {
2506 SmallVector<OpFoldResult> mixedGlobalSizes = op.getMixedGlobalSizes();
2507 OpFoldResult tensorDim0OpFoldResult = mixedGlobalSizes.back();
2508 Value tensorDim0;
2509 if (auto attr = dyn_cast<Attribute>(tensorDim0OpFoldResult))
2510 tensorDim0 =
2511 createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
2512 else
2513 tensorDim0 = cast<Value>(tensorDim0OpFoldResult);
2514
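// tensorDim0 straddles a dword boundary: its low 16 bits land in bits
// [63:48] (the top of sgpr1) and its high 16 bits in bits [79:64] (the
// bottom of sgpr2). setTensorDim1 performs the analogous split at offset 80.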
2515 Value c16 = createI32Constant(rewriter, loc, 16);
2516 Value tensorDim0High = LLVM::LShrOp::create(rewriter, loc, tensorDim0, c16);
2517 sgpr1 = setValueAtOffset(rewriter, loc, sgpr1, tensorDim0, 48);
2518 sgpr2 = setValueAtOffset(rewriter, loc, sgpr2, tensorDim0High, 48 + 16);
2519 return {sgpr1, sgpr2};
2520 }
2521
2522 std::pair<Value, Value> setTensorDim1(MakeDmaDescriptorOp op,
2523 OpAdaptor adaptor,
2524 ConversionPatternRewriter &rewriter,
2525 Location loc, Value sgpr2, Value sgpr3,
2526 ArrayRef<Value> consts) const {
2527 // TODO: Generalize to setTensorDimX.
2528 SmallVector<OpFoldResult> mixedGlobalSizes = op.getMixedGlobalSizes();
2529 OpFoldResult tensorDim1OpFoldResult = *(mixedGlobalSizes.rbegin() + 1);
2530 Value tensorDim1;
2531 if (auto attr = dyn_cast<Attribute>(tensorDim1OpFoldResult))
2532 tensorDim1 =
2533 createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
2534 else
2535 tensorDim1 = cast<Value>(tensorDim1OpFoldResult);
2536
2537 Value c16 = createI32Constant(rewriter, loc, 16);
2538 Value tensorDim1High = LLVM::LShrOp::create(rewriter, loc, tensorDim1, c16);
2539 sgpr2 = setValueAtOffset(rewriter, loc, sgpr2, tensorDim1, 80);
2540 sgpr3 = setValueAtOffset(rewriter, loc, sgpr3, tensorDim1High, 80 + 16);
2541 return {sgpr2, sgpr3};
2542 }
2543
2544 Value setTileDimX(MakeDmaDescriptorOp op, OpAdaptor adaptor,
2545 ConversionPatternRewriter &rewriter, Location loc,
2546 Value sgpr, ArrayRef<Value> consts, size_t dimX,
2547 int64_t offset) const {
2548 SmallVector<OpFoldResult> mixedSharedSizes = op.getMixedSharedSizes();
2549
2550 if (mixedSharedSizes.size() <= dimX)
2551 return sgpr;
2552
2553 OpFoldResult tileDimXOpFoldResult = *(mixedSharedSizes.rbegin() + dimX);
2554 Value tileDimX;
2555 if (auto attr = dyn_cast<Attribute>(tileDimXOpFoldResult))
2556 tileDimX =
2557 createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
2558 else
2559 tileDimX = cast<Value>(tileDimXOpFoldResult);
2560
2561 return setValueAtOffset(rewriter, loc, sgpr, tileDimX, offset);
2562 }
2563
2564 Value setTileDim0(MakeDmaDescriptorOp op, OpAdaptor adaptor,
2565 ConversionPatternRewriter &rewriter, Location loc,
2566 Value sgpr3, ArrayRef<Value> consts) const {
2567 return setTileDimX(op, adaptor, rewriter, loc, sgpr3, consts, 0, 112);
2568 }
2569
2570 Value setTileDim1(MakeDmaDescriptorOp op, OpAdaptor adaptor,
2571 ConversionPatternRewriter &rewriter, Location loc,
2572 Value sgpr4, ArrayRef<Value> consts) const {
2573 return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 1, 128);
2574 }
2575
2576 Value setTileDim2(MakeDmaDescriptorOp op, OpAdaptor adaptor,
2577 ConversionPatternRewriter &rewriter, Location loc,
2578 Value sgpr4, ArrayRef<Value> consts) const {
2579 return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 2, 144);
2580 }
2581
2582 std::pair<Value, Value>
2583 setTensorDimXStride(MakeDmaDescriptorOp op, OpAdaptor adaptor,
2584 ConversionPatternRewriter &rewriter, Location loc,
2585 Value sgprY, Value sgprZ, ArrayRef<Value> consts,
2586 size_t dimX, int64_t offset) const {
2587 SmallVector<OpFoldResult> mixedGlobalStrides = op.getMixedGlobalStrides();
2588
2589 if (mixedGlobalStrides.size() <= dimX)
2590 return {sgprY, sgprZ};
2591
2592 OpFoldResult tensorDimXStrideOpFoldResult =
2593 *(mixedGlobalStrides.rbegin() + dimX);
2594 Value tensorDimXStride;
2595 if (auto attr = dyn_cast<Attribute>(tensorDimXStrideOpFoldResult))
2596 tensorDimXStride =
2597 createI64Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
2598 else
2599 tensorDimXStride = cast<Value>(tensorDimXStrideOpFoldResult);
2600
2601 constexpr int64_t first48bits = (1ll << 48) - 1;
2602 Value mask = createI64Constant(rewriter, loc, first48bits);
2603 tensorDimXStride =
2604 LLVM::AndOp::create(rewriter, loc, mask, tensorDimXStride);
2605 IntegerType i32 = rewriter.getI32Type();
2606 Value tensorDimXStrideLow =
2607 LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStride);
2608
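// E.g. for the dim-0 stride at offset 160: 160 % 32 == 0 so shift is 32, the
// low 32 stride bits land in sgprY (absolute bits [191:160]), and stride
// bits [47:32] land in sgprZ (absolute bits [207:192]).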
2609 int64_t shift = (offset % 32) == 0 ? 32 : offset % 32;
2610 Value shiftVal = createI64Constant(rewriter, loc, shift);
2611 Value tensorDimXStrideHigh =
2612 LLVM::LShrOp::create(rewriter, loc, tensorDimXStride, shiftVal);
2613 tensorDimXStrideHigh =
2614 LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStrideHigh);
2615
2616 sgprY = setValueAtOffset(rewriter, loc, sgprY, tensorDimXStrideLow, offset);
2617 sgprZ = setValueAtOffset(rewriter, loc, sgprZ, tensorDimXStrideHigh,
2618 offset + shift);
2619 return {sgprY, sgprZ};
2620 }
2621
2622 std::pair<Value, Value>
2623 setTensorDim0Stride(MakeDmaDescriptorOp op, OpAdaptor adaptor,
2624 ConversionPatternRewriter &rewriter, Location loc,
2625 Value sgpr5, Value sgpr6, ArrayRef<Value> consts) const {
2626 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts,
2627 0, 160);
2628 }
2629
2630 std::pair<Value, Value>
2631 setTensorDim1Stride(MakeDmaDescriptorOp op, OpAdaptor adaptor,
2632 ConversionPatternRewriter &rewriter, Location loc,
2633 Value sgpr5, Value sgpr6, ArrayRef<Value> consts) const {
2634 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts,
2635 1, 208);
2636 }
2637
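// getDGroup1 packs the descriptor fields into eight 32-bit SGPRs using the
// helpers above. Bit offsets, as encoded there: workgroup mask at 0,
// data_size at 16, enable bits at 18-21, pad interval at 22, pad amount at
// 25, atomic barrier address at 32, tensor dims at 48 and 80, tile dims at
// 112/128/144, and the 48-bit tensor strides at 160 and 208.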
2638 Value getDGroup1(MakeDmaDescriptorOp op, OpAdaptor adaptor,
2639 ConversionPatternRewriter &rewriter, Location loc,
2640 ArrayRef<Value> consts) const {
2641 Value sgprs[8];
2642 for (int64_t i = 0; i < 8; i++) {
2643 sgprs[i] = consts[0];
2644 }
2645
2646 sgprs[0] = setWorkgroupMask(op, adaptor, rewriter, loc, sgprs[0]);
2647 sgprs[0] = setDataSize(op, adaptor, rewriter, loc, sgprs[0], consts);
2648 sgprs[0] = setAtomicBarrier(op, adaptor, rewriter, loc, sgprs[0], consts);
2649 sgprs[0] = setIterateEnable(op, adaptor, rewriter, loc, sgprs[0], consts);
2650 sgprs[0] = setPadEnable(op, adaptor, rewriter, loc, sgprs[0], consts);
2651 sgprs[0] = setEarlyTimeout(op, adaptor, rewriter, loc, sgprs[0], consts);
2652 sgprs[0] = setPadInterval(op, adaptor, rewriter, loc, sgprs[0], consts);
2653 sgprs[0] = setPadAmount(op, adaptor, rewriter, loc, sgprs[0], consts);
2654
2655 sgprs[1] =
2656 setAtomicBarrierAddress(op, adaptor, rewriter, loc, sgprs[1], consts);
2657 std::tie(sgprs[1], sgprs[2]) =
2658 setTensorDim0(op, adaptor, rewriter, loc, sgprs[1], sgprs[2], consts);
2659 std::tie(sgprs[2], sgprs[3]) =
2660 setTensorDim1(op, adaptor, rewriter, loc, sgprs[2], sgprs[3], consts);
2661
2662 sgprs[3] = setTileDim0(op, adaptor, rewriter, loc, sgprs[3], consts);
2663 sgprs[4] = setTileDim1(op, adaptor, rewriter, loc, sgprs[4], consts);
2664 sgprs[4] = setTileDim2(op, adaptor, rewriter, loc, sgprs[4], consts);
2665 std::tie(sgprs[5], sgprs[6]) = setTensorDim0Stride(
2666 op, adaptor, rewriter, loc, sgprs[5], sgprs[6], consts);
2667 std::tie(sgprs[6], sgprs[7]) = setTensorDim1Stride(
2668 op, adaptor, rewriter, loc, sgprs[6], sgprs[7], consts);
2669
2670 IntegerType i32 = rewriter.getI32Type();
2671 Type v8i32 = this->typeConverter->convertType(VectorType::get(8, i32));
2672 assert(v8i32 && "expected type conversion to succeed");
2673 Value dgroup1 = LLVM::PoisonOp::create(rewriter, loc, v8i32);
2674
2675 for (auto [sgpr, constant] : llvm::zip_equal(sgprs, consts)) {
2676 dgroup1 =
2677 LLVM::InsertElementOp::create(rewriter, loc, dgroup1, sgpr, constant);
2678 }
2679
2680 return dgroup1;
2681 }
2682
2683 LogicalResult
2684 matchAndRewrite(MakeDmaDescriptorOp op, OpAdaptor adaptor,
2685 ConversionPatternRewriter &rewriter) const override {
2686 if (chipset < kGfx1250)
2687 return op->emitOpError(
2688 "make_dma_descriptor is only supported on gfx1250");
2689
2690 if (op.getRank() > 2)
2691 return op->emitOpError("unimplemented");
2692
2693 Location loc = op.getLoc();
2694
2695 IntegerType i32 = rewriter.getI32Type();
2696 [[maybe_unused]] Type v4i32 =
2697 this->typeConverter->convertType(VectorType::get(4, i32));
2698 assert(v4i32 && "expected type conversion to succeed");
2699
2700 SmallVector<Value> consts;
2701 for (int64_t i = 0; i < 8; i++)
2702 consts.push_back(createI32Constant(rewriter, loc, i));
2703
2704 Value dgroup0 = this->getDGroup0(adaptor);
2705 Value dgroup1 = this->getDGroup1(op, adaptor, rewriter, loc, consts);
2706
2707 SmallVector<Value> results = {dgroup0, dgroup1};
2708 rewriter.replaceOpWithMultiple(op, {results});
2709 return success();
2710 }
2711};
2712
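// The conversion pass: parses the `chipset` option into a Chipset and applies
// the patterns above. From the command line this is typically driven as,
// e.g., `mlir-opt --convert-amdgpu-to-rocdl=chipset=gfx942`.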
2713struct ConvertAMDGPUToROCDLPass
2714 : public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
2715 using Base::Base;
2716
2717 void runOnOperation() override {
2718 MLIRContext *ctx = &getContext();
2719 FailureOr<Chipset> maybeChipset = Chipset::parse(chipset);
2720 if (failed(maybeChipset)) {
2721 emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
2722 return signalPassFailure();
2723 }
2724
2725 RewritePatternSet patterns(ctx);
2726 LLVMTypeConverter converter(ctx);
2727 converter.addConversion([&](TDMBaseType type) -> Type {
2728 Type i32 = IntegerType::get(type.getContext(), 32);
2729 return converter.convertType(VectorType::get(4, i32));
2730 });
2731
2732 populateAMDGPUToROCDLConversionPatterns(converter, patterns, *maybeChipset);
2733 LLVMConversionTarget target(getContext());
2734 target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
2735 target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
2736 target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
2737 if (failed(applyPartialConversion(getOperation(), target,
2738 std::move(patterns))))
2739 signalPassFailure();
2740 }
2741};
2742} // namespace
2743
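// The integer address spaces below correspond to the AMDGPU backend's buffer
// address spaces: 7 is the buffer fat pointer, 8 the buffer resource, and 9
// the buffer strided pointer (see the LLVM AMDGPUUsage documentation).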
2745 TypeConverter &typeConverter) {
2746 typeConverter.addTypeAttributeConversion(
2747 [](BaseMemRefType type, amdgpu::AddressSpaceAttr as)
2748 -> TypeConverter::AttributeConversionResult {
2749 MLIRContext *ctx = as.getContext();
2750 Type i64 = IntegerType::get(ctx, 64);
2751 switch (as.getValue()) {
2752 case amdgpu::AddressSpace::FatRawBuffer:
2753 return IntegerAttr::get(i64, 7);
2754 case amdgpu::AddressSpace::BufferRsrc:
2755 return IntegerAttr::get(i64, 8);
2756 case amdgpu::AddressSpace::FatStructuredBuffer:
2757 return IntegerAttr::get(i64, 9);
2758 }
2759 return TypeConverter::AttributeConversionResult::abort();
2760 });
2761}
2762
2763void mlir::populateAMDGPUToROCDLConversionPatterns(
2764 const LLVMTypeConverter &converter, RewritePatternSet &patterns,
2765 Chipset chipset) {
2767 patterns.add<
2768 FatRawBufferCastLowering,
2769 RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
2770 RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
2771 RawBufferOpLowering<RawBufferAtomicFaddOp,
2772 ROCDL::RawPtrBufferAtomicFaddOp>,
2773 RawBufferOpLowering<RawBufferAtomicFmaxOp,
2774 ROCDL::RawPtrBufferAtomicFmaxOp>,
2775 RawBufferOpLowering<RawBufferAtomicSmaxOp,
2776 ROCDL::RawPtrBufferAtomicSmaxOp>,
2777 RawBufferOpLowering<RawBufferAtomicUminOp,
2778 ROCDL::RawPtrBufferAtomicUminOp>,
2779 RawBufferOpLowering<RawBufferAtomicCmpswapOp,
2780 ROCDL::RawPtrBufferAtomicCmpSwap>,
2781 AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
2782 SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
2783 WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedMatrixOpLowering,
2784 ScaledExtPackedOpLowering, PackedScaledTruncOpLowering,
2785 PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
2786 GatherToLDSOpLowering, TransposeLoadOpLowering, AMDGPUPermlaneLowering,
2787 AMDGPUMakeDmaBaseLowering, AMDGPUMakeDmaDescriptorLowering>(converter,
2788 chipset);
2789 patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
2790}
static constexpr unsigned kSizePosInMemRefDescriptor
static constexpr unsigned kStridePosInMemRefDescriptor
static constexpr unsigned kOffsetPosInMemRefDescriptor
static constexpr unsigned kAllocatedPtrPosInMemRefDescriptor
static constexpr unsigned kAlignedPtrPosInMemRefDescriptor
static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound)
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class provides a shared interface for ranked and unranked memref types.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition Pattern.h:207
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition Pattern.h:213
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Value stride(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
Value size(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
result_range getResults()
Definition Operation.h:415
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isF64() const
Definition Types.cpp:41
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition Types.cpp:76
bool isF32() const
Definition Types.cpp:40
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition Types.cpp:88
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:56
bool isF16() const
Definition Types.cpp:38
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:122
bool isBF16() const
Definition Types.cpp:37
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
::mlir::Pass::Option< std::string > chipset
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDes...
Definition Pattern.cpp:471
Value composeValue(OpBuilder &builder, Location loc, ValueRange src, Type dstType)
Composes a set of src values into a single value of type dstType through series of bitcasts and vecto...
Definition Pattern.cpp:432
SmallVector< Value > decomposeValue(OpBuilder &builder, Location loc, Value src, Type dstType)
Decomposes a src value into a set of values of type dstType through series of bitcasts and vector ops...
Definition Pattern.cpp:393
bool hasOcpFp8(const Chipset &chipset)
Definition Chipset.h:52
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
void populateAMDGPUMemorySpaceAttributeConversions(TypeConverter &typeConverter)
Remap AMDGPU memory spaces to LLVM address spaces by mapping amdgpu::AddressSpace::fat_raw_buffer to ...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition Matchers.h:442
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, amdgpu::Chipset chipset)
Note: This function will also add conversions for the AMDGPU-specific address spaces,...
Represents the amdgpu gfx chipset version, e.g., gfx90a, gfx942, gfx1103.
Definition Chipset.h:22
unsigned majorVersion
Definition Chipset.h:23
unsigned minorVersion
Definition Chipset.h:24
static FailureOr< Chipset > parse(StringRef name)
Parses the chipset version string and returns the chipset on success, and failure otherwise.
Definition Chipset.cpp:14