//===- AMDGPUToROCDL.cpp - AMDGPU to ROCDL dialect conversion -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"

#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_CONVERTAMDGPUTOROCDLPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::amdgpu;

// Define commonly used chipset versions for convenience.
constexpr Chipset kGfx908 = Chipset(9, 0, 8);
constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
constexpr Chipset kGfx942 = Chipset(9, 4, 2);
constexpr Chipset kGfx950 = Chipset(9, 5, 0);
constexpr Chipset kGfx1250 = Chipset(12, 5, 0);

/// Convert an unsigned number `val` to i32.
static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
                                  Location loc, Value val) {
  IntegerType i32 = rewriter.getI32Type();
  // Force check that `val` is of int type.
  auto valTy = cast<IntegerType>(val.getType());
  if (i32 == valTy)
    return val;
  return valTy.getWidth() > 32
             ? Value(LLVM::TruncOp::create(rewriter, loc, i32, val))
             : Value(LLVM::ZExtOp::create(rewriter, loc, i32, val));
}

static Value createI32Constant(ConversionPatternRewriter &rewriter,
                               Location loc, int32_t value) {
  return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), value);
}

/// Convert an unsigned number `val` to i64.
static Value convertUnsignedToI64(ConversionPatternRewriter &rewriter,
                                  Location loc, Value val) {
  IntegerType i64 = rewriter.getI64Type();
  // Force check that `val` is of int type.
  auto valTy = cast<IntegerType>(val.getType());
  if (i64 == valTy)
    return val;
  return valTy.getWidth() > 64
             ? Value(LLVM::TruncOp::create(rewriter, loc, i64, val))
             : Value(LLVM::ZExtOp::create(rewriter, loc, i64, val));
}

static Value createI64Constant(ConversionPatternRewriter &rewriter,
                               Location loc, int64_t value) {
  return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), value);
}

/// Returns the linear index used to access an element in the memref.
static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
                               Location loc, MemRefDescriptor &memRefDescriptor,
                               ValueRange indices, ArrayRef<int64_t> strides) {
  IntegerType i32 = rewriter.getI32Type();
  Value index;
  for (auto [i, increment, stride] : llvm::enumerate(indices, strides)) {
    if (stride != 1) { // Skip if stride is 1.
      Value strideValue =
          ShapedType::isDynamic(stride)
              ? convertUnsignedToI32(rewriter, loc,
                                     memRefDescriptor.stride(rewriter, loc, i))
              : LLVM::ConstantOp::create(rewriter, loc, i32, stride);
      increment = LLVM::MulOp::create(rewriter, loc, increment, strideValue);
    }
    index = index ? LLVM::AddOp::create(rewriter, loc, index, increment)
                  : increment;
  }
  return index ? index : createI32Constant(rewriter, loc, 0);
}
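// For example (illustration only): for a memref<4x8xf32> with row-major
// strides [8, 1] and indices (%i, %j), the loop above emits %i * 8 + %j,
// skipping the multiply for the unit-stride innermost dimension.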

/// Compute the contents of the `num_records` field for a given memref
/// descriptor - that is, the number of bytes that's one element past the
/// greatest possible valid index into the memref.
static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc,
                           MemRefType memrefType,
                           MemRefDescriptor &memrefDescriptor,
                           ArrayRef<int64_t> strides,
                           int64_t elementByteWidth) {
  if (memrefType.hasStaticShape() &&
      !llvm::any_of(strides, ShapedType::isDynamic)) {
    int64_t size = memrefType.getRank() == 0 ? 1 : 0;
    ArrayRef<int64_t> shape = memrefType.getShape();
    for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
      size = std::max(shape[i] * strides[i], size);
    size = size * elementByteWidth;
    return createI64Constant(rewriter, loc, size);
  }
  Value maxIndex;
  for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
    Value size = memrefDescriptor.size(rewriter, loc, i);
    Value stride = memrefDescriptor.stride(rewriter, loc, i);
    Value maxThisDim = LLVM::MulOp::create(rewriter, loc, size, stride);
    maxIndex = maxIndex
                   ? LLVM::UMaxOp::create(rewriter, loc, maxIndex, maxThisDim)
                   : maxThisDim;
  }
  Value maxIndexI64 = convertUnsignedToI64(rewriter, loc, maxIndex);
  Value byteWidthConst = createI64Constant(rewriter, loc, elementByteWidth);
  return LLVM::MulOp::create(rewriter, loc, maxIndexI64, byteWidthConst);
}
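// For example (illustration only): a static memref<16x4xf32> with strides
// [4, 1] yields max(16 * 4, 4 * 1) * 4 bytes = 256, i.e. one past the last
// valid byte offset.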

static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc,
                            Value basePointer, Value numRecords,
                            bool boundsCheck, amdgpu::Chipset chipset,
                            Value cacheSwizzleStride = nullptr,
                            unsigned addressSpace = 8) {
  // The stride value is generally 0. However, on MI-300 and onward, you can
  // enable a cache swizzling mode by setting bit 14 of the stride field
  // and setting that stride to a cache stride.
  Type i16 = rewriter.getI16Type();
  Value stride;
  if (chipset.majorVersion == 9 && chipset >= kGfx942 && cacheSwizzleStride) {
    Value cacheStrideZext =
        LLVM::ZExtOp::create(rewriter, loc, i16, cacheSwizzleStride);
    Value swizzleBit = LLVM::ConstantOp::create(
        rewriter, loc, i16, rewriter.getI16IntegerAttr(1 << 14));
    stride = LLVM::OrOp::create(rewriter, loc, cacheStrideZext, swizzleBit,
                                /*isDisjoint=*/true);
  } else {
    stride = LLVM::ConstantOp::create(rewriter, loc, i16,
                                      rewriter.getI16IntegerAttr(0));
  }
  // Flag word:
  // bits 0-11: dst sel, ignored by these intrinsics
  // bits 12-14: num format (ignored, must be nonzero, 7=float)
  // bits 15-18: data format (ignored, must be nonzero, 4=32bit)
  // bit 19: In nested heap (0 here)
  // bit 20: Behavior on unmap (0 means "return 0 / ignore")
  // bits 21-22: Index stride for swizzles (N/A)
  // bit 23: Add thread ID (0)
  // bit 24: Reserved to 1 (RDNA) or 0 (CDNA)
  // bits 25-26: Reserved (0)
  // bit 27: Buffer is non-volatile (CDNA only)
  // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 =
  // none, 3 = either swizzles or testing against offset field) RDNA only
  // bits 30-31: Type (must be 0)
  uint32_t flags = (7 << 12) | (4 << 15);
  if (chipset.majorVersion >= 10) {
    flags |= (1 << 24);
    uint32_t oob = boundsCheck ? 3 : 2;
    flags |= (oob << 28);
  }
  Value flagsConst = createI32Constant(rewriter, loc, flags);
  Type rsrcType =
      LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
  Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
      loc, rsrcType, basePointer, stride, numRecords, flagsConst);
  return resource;
}
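// For example (illustration only): on a gfx11 (RDNA) target with bounds
// checking enabled, the flag word evaluates to
// (7 << 12) | (4 << 15) | (1 << 24) | (3 << 28) = 0x31027000.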

namespace {
struct FatRawBufferCastLowering
    : public ConvertOpToLLVMPattern<FatRawBufferCastOp> {
  FatRawBufferCastLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<FatRawBufferCastOp>(converter),
        chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value memRef = adaptor.getSource();
    Value unconvertedMemref = op.getSource();
    MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());
    MemRefDescriptor descriptor(memRef);

    DataLayout dataLayout = DataLayout::closest(op);
    int64_t elementByteWidth =
        dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;

    int64_t unusedOffset = 0;
    SmallVector<int64_t, 5> strideVals;
    if (failed(memrefType.getStridesAndOffset(strideVals, unusedOffset)))
      return op.emitOpError("Can't lower non-stride-offset memrefs");

    Value numRecords = adaptor.getValidBytes();
    if (!numRecords)
      numRecords = getNumRecords(rewriter, loc, memrefType, descriptor,
                                 strideVals, elementByteWidth);

    Value basePointer =
        adaptor.getResetOffset()
            ? descriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
                                   memrefType)
            : descriptor.alignedPtr(rewriter, loc);

    Value offset = adaptor.getResetOffset()
                       ? LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
                                                  rewriter.getIndexAttr(0))
                       : descriptor.offset(rewriter, loc);

    bool hasSizes = memrefType.getRank() > 0;
    // No need to unpack() and pack() all the individual sizes and strides,
    // so we'll just extract the arrays.
    Value sizes = hasSizes
                      ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
                                                     kSizePosInMemRefDescriptor)
                      : Value{};
    Value strides =
        hasSizes ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
                                                kStridePosInMemRefDescriptor)
                 : Value{};

    Value fatPtr = makeBufferRsrc(
        rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(),
        chipset, adaptor.getCacheSwizzleStride(), /*addressSpace=*/7);

    Value result = MemRefDescriptor::poison(
        rewriter, loc,
        getTypeConverter()->convertType(op.getResult().getType()));
    SmallVector<int64_t> pos{kAllocatedPtrPosInMemRefDescriptor};
    result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr, pos);
    result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr,
                                         kAlignedPtrPosInMemRefDescriptor);
    result = LLVM::InsertValueOp::create(rewriter, loc, result, offset,
                                         kOffsetPosInMemRefDescriptor);
    if (hasSizes) {
      result = LLVM::InsertValueOp::create(rewriter, loc, result, sizes,
                                           kSizePosInMemRefDescriptor);
      result = LLVM::InsertValueOp::create(rewriter, loc, result, strides,
                                           kStridePosInMemRefDescriptor);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// Define lowering patterns for raw buffer ops
template <typename GpuOp, typename Intrinsic>
struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
  RawBufferOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}

  Chipset chipset;
  static constexpr uint32_t maxVectorOpWidth = 128;

  LogicalResult
  matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = gpuOp.getLoc();
    Value memref = adaptor.getMemref();
    Value unconvertedMemref = gpuOp.getMemref();
    MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());

    if (chipset.majorVersion < 9)
      return gpuOp.emitOpError("raw buffer ops require GCN or higher");

    Value storeData = adaptor.getODSOperands(0)[0];
    if (storeData == memref) // no write component to this op
      storeData = Value();
    Type wantedDataType;
    if (storeData)
      wantedDataType = storeData.getType();
    else
      wantedDataType = gpuOp.getODSResults(0)[0].getType();

    Value atomicCmpData = Value();
    // Operand index 1 of a load is the indices, trying to read them can crash.
    if (storeData) {
      Value maybeCmpData = adaptor.getODSOperands(1)[0];
      if (maybeCmpData != memref)
        atomicCmpData = maybeCmpData;
    }

    Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);

    Type i32 = rewriter.getI32Type();

    // Get the type size in bytes.
    DataLayout dataLayout = DataLayout::closest(gpuOp);
    int64_t elementByteWidth =
        dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;
    Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);

    // If we want to load a vector<NxT> with total size <= 32
    // bits, use a scalar load and bitcast it. Similarly, if bitsize(T) < 32
    // and the total load size is >= 32, use a vector load of N / (bitsize(T) /
    // 32) x i32 and bitcast. Also, the CAS intrinsic requires integer operands,
    // so bitcast any floats to integers.
    Type llvmBufferValType = llvmWantedDataType;
    if (atomicCmpData) {
      if (auto floatType = dyn_cast<FloatType>(wantedDataType))
        llvmBufferValType = this->getTypeConverter()->convertType(
            rewriter.getIntegerType(floatType.getWidth()));
    }
    if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
      uint32_t vecLen = dataVector.getNumElements();
      uint32_t elemBits =
          dataLayout.getTypeSizeInBits(dataVector.getElementType());
      uint32_t totalBits = elemBits * vecLen;
      bool usePackedFp16 =
          isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
      if (totalBits > maxVectorOpWidth)
        return gpuOp.emitOpError(
            "Total width of loads or stores must be no more than " +
            Twine(maxVectorOpWidth) + " bits, but we call for " +
            Twine(totalBits) +
            " bits. This should've been caught in validation");
      if (!usePackedFp16 && elemBits < 32) {
        if (totalBits > 32) {
          if (totalBits % 32 != 0)
            return gpuOp.emitOpError("Load or store of more than 32-bits that "
                                     "doesn't fit into words. Can't happen\n");
          llvmBufferValType = this->typeConverter->convertType(
              VectorType::get(totalBits / 32, i32));
        } else {
          llvmBufferValType = this->typeConverter->convertType(
              rewriter.getIntegerType(totalBits));
        }
      }
    }
    if (auto vecType = dyn_cast<VectorType>(llvmBufferValType)) {
      // Buffer intrinsics don't support 1-element vectors; cast them to
      // scalars.
      if (vecType.getNumElements() == 1)
        llvmBufferValType = vecType.getElementType();
    }

    SmallVector<Value, 6> args;
    if (storeData) {
      if (llvmBufferValType != llvmWantedDataType) {
        Value castForStore = LLVM::BitcastOp::create(
            rewriter, loc, llvmBufferValType, storeData);
        args.push_back(castForStore);
      } else {
        args.push_back(storeData);
      }
    }

    if (atomicCmpData) {
      if (llvmBufferValType != llvmWantedDataType) {
        Value castForCmp = LLVM::BitcastOp::create(
            rewriter, loc, llvmBufferValType, atomicCmpData);
        args.push_back(castForCmp);
      } else {
        args.push_back(atomicCmpData);
      }
    }

    // Construct buffer descriptor from memref, attributes
    int64_t offset = 0;
    SmallVector<int64_t, 5> strides;
    if (failed(memrefType.getStridesAndOffset(strides, offset)))
      return gpuOp.emitOpError("Can't lower non-stride-offset memrefs");

    MemRefDescriptor memrefDescriptor(memref);

    Value ptr = memrefDescriptor.bufferPtr(
        rewriter, loc, *this->getTypeConverter(), memrefType);
    Value numRecords = getNumRecords(
        rewriter, loc, memrefType, memrefDescriptor, strides, elementByteWidth);
    Value resource = makeBufferRsrc(rewriter, loc, ptr, numRecords,
                                    adaptor.getBoundsCheck(), chipset);
    args.push_back(resource);

    // Indexing (voffset)
    Value voffset = getLinearIndexI32(rewriter, loc, memrefDescriptor,
                                      adaptor.getIndices(), strides);
    if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
        indexOffset && *indexOffset > 0) {
      Value extraOffsetConst = createI32Constant(rewriter, loc, *indexOffset);
      voffset = voffset ? LLVM::AddOp::create(rewriter, loc, voffset,
                                              extraOffsetConst)
                        : extraOffsetConst;
    }
    voffset = LLVM::MulOp::create(rewriter, loc, voffset, byteWidthConst);
    args.push_back(voffset);

    // SGPR offset.
    Value sgprOffset = adaptor.getSgprOffset();
    if (!sgprOffset)
      sgprOffset = createI32Constant(rewriter, loc, 0);
    sgprOffset = LLVM::MulOp::create(rewriter, loc, sgprOffset, byteWidthConst);
    args.push_back(sgprOffset);

    // bit 0: GLC = 0 (atomics drop value, less coherency)
    // bits 1-2: SLC, DLC = 0 (similarly)
    // bit 3: swizzled (0 for raw)
    args.push_back(createI32Constant(rewriter, loc, 0));

    llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
                                           llvmBufferValType);
    Operation *lowered = Intrinsic::create(rewriter, loc, resultTypes, args,
                                           ArrayRef<NamedAttribute>());
    if (lowered->getNumResults() == 1) {
      Value replacement = lowered->getResult(0);
      if (llvmBufferValType != llvmWantedDataType) {
        replacement = LLVM::BitcastOp::create(rewriter, loc, llvmWantedDataType,
                                              replacement);
      }
      rewriter.replaceOp(gpuOp, replacement);
    } else {
      rewriter.eraseOp(gpuOp);
    }
    return success();
  }
};
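// As a rough sketch of the result (exact IR syntax elided): an
// amdgpu.raw_buffer_load %src[%i] : memref<64xf32> -> f32 becomes a
// rocdl.make.buffer.rsrc built from the base pointer and num_records = 256,
// a voffset of %i * 4, an sgpr offset of 0, an aux flag of 0, and a raw
// buffer load intrinsic returning f32.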

// TODO: the AMDGPU backend already has all this bitpacking logic; we should
// move it to some common place.
/// Vmcnt, Expcnt and Lgkmcnt are decoded as follows:
/// Vmcnt = Waitcnt[3:0] (pre-gfx9)
/// Vmcnt = Waitcnt[15:14,3:0] (gfx9,10)
/// Vmcnt = Waitcnt[15:10] (gfx11)
/// Expcnt = Waitcnt[6:4] (pre-gfx11)
/// Expcnt = Waitcnt[2:0] (gfx11)
/// Lgkmcnt = Waitcnt[11:8] (pre-gfx10)
/// Lgkmcnt = Waitcnt[13:8] (gfx10)
/// Lgkmcnt = Waitcnt[9:4] (gfx11)
static FailureOr<unsigned> encodeWaitcnt(Chipset chipset, unsigned vmcnt,
                                         unsigned expcnt, unsigned lgkmcnt) {
  if (chipset.majorVersion < 9) {
    vmcnt = std::min(15u, vmcnt);
    expcnt = std::min(7u, expcnt);
    lgkmcnt = std::min(15u, lgkmcnt);
    return vmcnt | (expcnt << 4) | (lgkmcnt << 8);
  }
  if (chipset.majorVersion == 9) {
    vmcnt = std::min(63u, vmcnt);
    expcnt = std::min(7u, expcnt);
    lgkmcnt = std::min(15u, lgkmcnt);
    unsigned lowBits = vmcnt & 0xF;
    unsigned highBits = (vmcnt >> 4) << 14;
    unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
    return lowBits | highBits | otherCnts;
  }
  if (chipset.majorVersion == 10) {
    vmcnt = std::min(63u, vmcnt);
    expcnt = std::min(7u, expcnt);
    lgkmcnt = std::min(63u, lgkmcnt);
    unsigned lowBits = vmcnt & 0xF;
    unsigned highBits = (vmcnt >> 4) << 14;
    unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
    return lowBits | highBits | otherCnts;
  }
  if (chipset.majorVersion == 11) {
    vmcnt = std::min(63u, vmcnt);
    expcnt = std::min(7u, expcnt);
    lgkmcnt = std::min(63u, lgkmcnt);
    return (vmcnt << 10) | expcnt | (lgkmcnt << 4);
  }
  return failure();
}

struct MemoryCounterWaitOpLowering
    : public ConvertOpToLLVMPattern<MemoryCounterWaitOp> {
  MemoryCounterWaitOpLowering(const LLVMTypeConverter &converter,
                              Chipset chipset)
      : ConvertOpToLLVMPattern<MemoryCounterWaitOp>(converter),
        chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(MemoryCounterWaitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset.majorVersion >= 12) {
      Location loc = op.getLoc();
      if (std::optional<int> ds = adaptor.getDs())
        ROCDL::WaitDscntOp::create(rewriter, loc, *ds);

      if (std::optional<int> load = adaptor.getLoad())
        ROCDL::WaitLoadcntOp::create(rewriter, loc, *load);

      if (std::optional<int> store = adaptor.getStore())
        ROCDL::WaitStorecntOp::create(rewriter, loc, *store);

      if (std::optional<int> exp = adaptor.getExp())
        ROCDL::WaitExpcntOp::create(rewriter, loc, *exp);

      rewriter.eraseOp(op);
      return success();
    }

    auto getVal = [](Attribute attr) -> unsigned {
      if (attr)
        return cast<IntegerAttr>(attr).getInt();

      // This value will be clamped to the maximum value for the chipset.
      return 1024;
    };
    unsigned ds = getVal(adaptor.getDsAttr());
    unsigned exp = getVal(adaptor.getExpAttr());

    unsigned vmcnt = 1024;
    Attribute load = adaptor.getLoadAttr();
    Attribute store = adaptor.getStoreAttr();
    if (load && store) {
      vmcnt = getVal(load) + getVal(store);
    } else if (load) {
      vmcnt = getVal(load);
    } else if (store) {
      vmcnt = getVal(store);
    }

    FailureOr<unsigned> waitcnt = encodeWaitcnt(chipset, vmcnt, exp, ds);
    if (failed(waitcnt))
      return op.emitOpError("unsupported chipset");

    rewriter.replaceOpWithNewOp<ROCDL::SWaitcntOp>(op, *waitcnt);
    return success();
  }
};
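// For instance (a sketch; exact op spellings elided): an
// amdgpu.memory_counter_wait with load = 0 becomes a rocdl.s.wait.loadcnt of
// 0 on gfx12+, and on older chipsets a single rocdl.s.waitcnt with vmcnt = 0
// and the remaining counters clamped to their chipset maxima.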

struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
  LDSBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // This ensures that waits on global memory aren't introduced on
    // chips that don't have the BackOffBarrier feature enabled in LLVM.
    bool requiresInlineAsm = chipset < kGfx90a;

    Attribute mmra =
        rewriter.getAttr<LLVM::MMRATagAttr>("amdgpu-synchronize-as", "local");
    // Note: while there *is* a workgroup-one-as scope, this, when combined with
    // the MMRA, will lead to the fence having no effect. This is because the
    // codepaths for an atomic load or store will observe that a
    // one-address-space atomic to LDS requires no synchronization because
    // operations on LDS are totally ordered with respect to each other, and so
    // will not emit the correct waitcnt operations that these fences are
    // intended to produce. Therefore, we use a broader type of fence and rely
    // on the MMRA to relax it to the semantics we want.
    StringRef scope = "workgroup";

    auto relFence = LLVM::FenceOp::create(rewriter, loc,
                                          LLVM::AtomicOrdering::release, scope);
    relFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra);
    if (requiresInlineAsm) {
      auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
                                                      LLVM::AsmDialect::AD_ATT);
      const char *asmStr = ";;;WARNING: BREAKS DEBUG WATCHES\ns_barrier";
      const char *constraints = "";
      LLVM::InlineAsmOp::create(
          rewriter, loc,
          /*resultTypes=*/TypeRange(), /*operands=*/ValueRange(),
          /*asm_string=*/asmStr, constraints, /*has_side_effects=*/true,
          /*is_align_stack=*/false, LLVM::TailCallKind::None,
          /*asm_dialect=*/asmDialectAttr,
          /*operand_attrs=*/ArrayAttr());
    } else if (chipset.majorVersion < 12) {
      ROCDL::SBarrierOp::create(rewriter, loc);
    } else {
      ROCDL::BarrierSignalOp::create(rewriter, loc, -1);
      ROCDL::BarrierWaitOp::create(rewriter, loc, -1);
    }

    auto acqFence = LLVM::FenceOp::create(rewriter, loc,
                                          LLVM::AtomicOrdering::acquire, scope);
    acqFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra);
    rewriter.replaceOp(op, acqFence);
    return success();
  }
};
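// Illustrative result (a sketch): on gfx90a through gfx11, amdgpu.lds_barrier
// becomes
//   llvm.fence syncscope("workgroup") release   (tagged with the LDS MMRA)
//   rocdl.s.barrier
//   llvm.fence syncscope("workgroup") acquire   (same MMRA)
// while gfx12+ uses a barrier signal/wait pair in place of the single
// s_barrier.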

struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
  SchedBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<SchedBarrierOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<ROCDL::SchedBarrier>(op,
                                                     (uint32_t)op.getOpts());
    return success();
  }
};

} // namespace

/// Converts an MFMA vector operand from the MLIR AMDGPU dialect convention to
/// the ROCDL and LLVM AMDGPU intrinsics convention.
///
/// Specifically:
/// 1. If the element type is bfloat16, bitcast it to i16 unless the ROCDL
/// intrinsic allows bf16. Newer MFMAs support bf16 operands; see
/// IntrinsicsAMDGPU.td for reference.
/// 2. If instead we have a quantity wider than 64 bits, use a <N / 4 x i32>
/// vector instead, which is what the f8f6f4 intrinsics use.
/// 3. If `input` is a vector of N <= 8 bytes, bitcast it to a (N * 8)-bit
/// integer.
///
/// Note that the type of `input` has already been LLVM type converted:
/// therefore 8-bit and smaller floats are represented as their corresponding
/// `iN` integers.
static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
                                      Location loc, Value input,
                                      bool allowBf16 = true) {
  Type inputType = input.getType();
  if (auto vectorType = dyn_cast<VectorType>(inputType)) {
    if (vectorType.getElementType().isBF16() && !allowBf16)
      return LLVM::BitcastOp::create(
          rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
    if (vectorType.getElementType().isInteger(8) &&
        vectorType.getNumElements() <= 8)
      return LLVM::BitcastOp::create(
          rewriter, loc,
          rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
    if (isa<IntegerType>(vectorType.getElementType()) &&
        vectorType.getElementTypeBitWidth() <= 8) {
      int64_t numWords = llvm::divideCeil(
          vectorType.getNumElements() * vectorType.getElementTypeBitWidth(),
          32);
      return LLVM::BitcastOp::create(
          rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()),
          input);
    }
  }
  return input;
}
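// For example (illustration only): with allowBf16 = false a vector<4xbf16>
// operand becomes vector<4xi16>, a vector<8xi8> becomes i64, and a
// vector<32xi8> (256 bits) becomes vector<8xi32>.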

/// Converts the scaled MFMA operands, `scalesA` and `scalesB`, from the MLIR
/// AMDGPU dialect convention to the ROCDL and LLVM AMDGPU intrinsics
/// convention.
///
/// Specifically:
/// 1. If `input` is an i8 value, zero-extend it to i32.
/// 2. If `input` is a vector of length 4 and type i8, bitcast it to i32.
///
/// Note that the type of `input` has already been LLVM type converted:
/// therefore 8-bit and smaller floats are represented as their corresponding
/// `iN` integers.
static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter,
                                  Location loc, Value input) {
  Type inputType = input.getType();
  Type outputType = rewriter.getI32Type();
  if (auto intType = dyn_cast<IntegerType>(inputType))
    return LLVM::ZExtOp::create(rewriter, loc, outputType, input);
  return LLVM::BitcastOp::create(rewriter, loc, outputType, input);
}

/// Push an input operand. If it is a float type, nothing to do. If it is
/// an integer type, then we need to also push its signedness (1 for signed,
/// 0 for unsigned) and we need to pack the input 16xi8 vector into a 4xi32
/// vector (or the 8xi8 vector into a 2xi32 one for gfx12+).
/// We also need to convert bfloat inputs to i16 to account for the bfloat
/// intrinsics having been defined before the AMD backend supported bfloat. We
/// similarly need to pack 8-bit float types into integers as if they were i8
/// (which they are for the backend's purposes).
static void wmmaPushInputOperand(
    ConversionPatternRewriter &rewriter, Location loc,
    const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput,
    Value mlirInput, SmallVectorImpl<Value> &operands,
    SmallVectorImpl<NamedAttribute> &attrs, StringRef attrName) {
  Type inputType = llvmInput.getType();
  auto vectorType = dyn_cast<VectorType>(inputType);
  if (!vectorType) {
    operands.push_back(llvmInput);
    return;
  }
  Type elemType = vectorType.getElementType();
  if (elemType.getIntOrFloatBitWidth() > 8) {
    operands.push_back(llvmInput);
    return;
  }

  // We need to check the type of the input before conversion to properly test
  // for int8. This is because, in LLVM, the fp8 type is converted to int8, so
  // the fp8/int8 information is lost during the conversion process.
  auto mlirInputType = cast<VectorType>(mlirInput.getType());
  bool isInputInteger = mlirInputType.getElementType().isInteger();
  if (isInputInteger) {
    // If the element type is 8-bit signed or unsigned, ignore the isUnsigned
    // flag.
    bool localIsUnsigned = isUnsigned;
    if (elemType.isUnsignedInteger()) {
      localIsUnsigned = true;
    } else if (elemType.isSignedInteger()) {
      localIsUnsigned = false;
    }
    attrs.push_back(
        NamedAttribute(attrName, rewriter.getBoolAttr(!localIsUnsigned)));
  }

  int64_t numBits =
      vectorType.getNumElements() * elemType.getIntOrFloatBitWidth();
  Type i32 = rewriter.getI32Type();
  Type intrinsicInType = numBits <= 32
                             ? (Type)rewriter.getIntegerType(numBits)
                             : (Type)VectorType::get(numBits / 32, i32);
  auto llvmIntrinsicInType = typeConverter->convertType(intrinsicInType);
  Value castInput = rewriter.createOrFold<LLVM::BitcastOp>(
      loc, llvmIntrinsicInType, llvmInput);
  // The wave64-mode 16x16x16 intrinsics that take 4-bit integers only need
  // (256 / 64) * 4 = 16 bits of input (on gfx12+) but take i32 arguments.
  // Add in the zeros here.
  if (numBits < 32)
    castInput = LLVM::ZExtOp::create(rewriter, loc, i32, castInput);
  operands.push_back(castInput);
}
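// For example (illustration only): a vector<16xi8> input is bitcast to
// vector<4xi32> and pushed after its "signA"/"signB" attribute, while a
// vector<4xi8> input (32 bits) is bitcast to i32 directly.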

/// Push the output operand. For many cases this is only pushing the output in
/// the operand list. But when we have f16 -> f16 or bf16 -> bf16 intrinsics,
/// since the same number of VGPRs is used, we need to decide whether to store
/// the result in the upper or lower 16 bits of the VGPRs. To store the result
/// in the lower 16 bits, set subwordOffset to 1; otherwise the result will be
/// stored in the upper part. The subwordOffset must not be set for gfx12, as
/// the instructions have been changed to return fewer registers instead.
static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
                                  Location loc,
                                  const TypeConverter *typeConverter,
                                  Value output, int32_t subwordOffset,
                                  bool clamp, SmallVectorImpl<Value> &operands,
                                  SmallVectorImpl<NamedAttribute> &attrs) {
  Type inputType = output.getType();
  auto vectorType = dyn_cast<VectorType>(inputType);
  Type elemType = vectorType.getElementType();
  operands.push_back(output);
  if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) {
    attrs.push_back(
        NamedAttribute("opsel", rewriter.getBoolAttr(subwordOffset)));
  } else if (elemType.isInteger(32)) {
    attrs.push_back(NamedAttribute("clamp", rewriter.getBoolAttr(clamp)));
  }
}

/// Return true if `type` is the E5M2 variant of an 8-bit float that is
/// supported by the `_bf8` instructions on the given `chipset`.
static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type) {
  return (chipset == kGfx942 && isa<Float8E5M2FNUZType>(type)) ||
         (hasOcpFp8(chipset) && isa<Float8E5M2Type>(type));
}

/// Return true if `type` is the E4M3FN variant of an 8-bit float that is
/// supported by the `_fp8` instructions on the given `chipset`.
static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type) {
  return (chipset == kGfx942 && isa<Float8E4M3FNUZType>(type)) ||
         (hasOcpFp8(chipset) && isa<Float8E4M3FNType>(type));
}

/// Return the `rocdl` intrinsic corresponding to an MFMA operation `mfma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
                                                  Chipset chipset) {
  uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
           b = mfma.getBlocks();
  Type sourceElem = getElementTypeOrSelf(mfma.getSourceA().getType());
  Type destElem = getElementTypeOrSelf(mfma.getDestC().getType());

  if (sourceElem.isF32() && destElem.isF32()) {
    if (mfma.getReducePrecision() && chipset >= kGfx942) {
      if (m == 32 && n == 32 && k == 4 && b == 1)
        return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
      if (m == 16 && n == 16 && k == 8 && b == 1)
        return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
    }
    if (m == 32 && n == 32 && k == 1 && b == 2)
      return ROCDL::mfma_f32_32x32x1f32::getOperationName();
    if (m == 16 && n == 16 && k == 1 && b == 4)
      return ROCDL::mfma_f32_16x16x1f32::getOperationName();
    if (m == 4 && n == 4 && k == 1 && b == 16)
      return ROCDL::mfma_f32_4x4x1f32::getOperationName();
    if (m == 32 && n == 32 && k == 2 && b == 1)
      return ROCDL::mfma_f32_32x32x2f32::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f32_16x16x4f32::getOperationName();
  }

  if (sourceElem.isF16() && destElem.isF32()) {
    if (chipset >= kGfx950) {
      if (m == 32 && n == 32 && k == 16 && b == 1)
        return ROCDL::mfma_f32_32x32x16_f16::getOperationName();
      if (m == 16 && n == 16 && k == 32 && b == 1)
        return ROCDL::mfma_f32_16x16x32_f16::getOperationName();
    }
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_f32_32x32x4f16::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_f32_16x16x4f16::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_f32_4x4x4f16::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_f32_32x32x8f16::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_f32_16x16x16f16::getOperationName();
  }

  if (sourceElem.isBF16() && destElem.isF32()) {
    if (chipset >= kGfx950) {
      if (m == 32 && n == 32 && k == 16 && b == 1)
        return ROCDL::mfma_f32_32x32x16_bf16::getOperationName();
      if (m == 16 && n == 16 && k == 32 && b == 1)
        return ROCDL::mfma_f32_16x16x32_bf16::getOperationName();
    }
    if (chipset >= kGfx90a) {
      if (m == 32 && n == 32 && k == 4 && b == 2)
        return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
      if (m == 16 && n == 16 && k == 4 && b == 4)
        return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
      if (m == 4 && n == 4 && k == 4 && b == 16)
        return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
      if (m == 32 && n == 32 && k == 8 && b == 1)
        return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
      if (m == 16 && n == 16 && k == 16 && b == 1)
        return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
    }
    if (m == 32 && n == 32 && k == 2 && b == 2)
      return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
    if (m == 16 && n == 16 && k == 2 && b == 4)
      return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
    if (m == 4 && n == 4 && k == 2 && b == 16)
      return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
    if (m == 32 && n == 32 && k == 4 && b == 1)
      return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
    if (m == 16 && n == 16 && k == 8 && b == 1)
      return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
  }

  if (sourceElem.isInteger(8) && destElem.isInteger(32)) {
    if (chipset >= kGfx950) {
      if (m == 32 && n == 32 && k == 32 && b == 1)
        return ROCDL::mfma_i32_32x32x32_i8::getOperationName();
      if (m == 16 && n == 16 && k == 64 && b == 1)
        return ROCDL::mfma_i32_16x16x64_i8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_i32_32x32x4i8::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_i32_16x16x4i8::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_i32_4x4x4i8::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_i32_32x32x8i8::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_i32_16x16x16i8::getOperationName();
    if (m == 32 && n == 32 && k == 16 && b == 1 && chipset >= kGfx942)
      return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
    if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >= kGfx942)
      return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
  }

  if (sourceElem.isF64() && destElem.isF64() && chipset >= kGfx90a) {
    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f64_16x16x4f64::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 4)
      return ROCDL::mfma_f64_4x4x4f64::getOperationName();
  }

  if (destElem.isF32() && typeIsExpectedBf8ForChipset(chipset, sourceElem)) {
    // Known to be correct because there are no scalar f8 instructions and
    // because a length mismatch will have been caught by the verifier.
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
    }
  }

  if (destElem.isF32() && typeIsExpectedFp8ForChipset(chipset, sourceElem)) {
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
    }
  }

  return std::nullopt;
}
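// For example (illustration only): an amdgpu.mfma with m = n = 16, k = 16,
// blocks = 1, f16 sources, and an f32 destination selects
// rocdl.mfma.f32.16x16x16f16 on any gfx9 chipset.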

static std::optional<uint32_t> mfmaTypeSelectCode(Type mlirElemType) {
  return TypeSwitch<Type, std::optional<uint32_t>>(mlirElemType)
      .Case([](Float8E4M3FNType) { return 0u; })
      .Case([](Float8E5M2Type) { return 1u; })
      .Case([](Float6E2M3FNType) { return 2u; })
      .Case([](Float6E3M2FNType) { return 3u; })
      .Case([](Float4E2M1FNType) { return 4u; })
      .Default(std::nullopt);
}

/// If there is a scaled MFMA instruction for the input element types `aType`
/// and `bType`, output type `destType`, problem size M, N, K, and B (number of
/// blocks) on the given `chipset`, return a tuple consisting of the
/// OperationName of the intrinsic and the type codes that need to be passed to
/// that intrinsic. Note that this is also used to implement some un-scaled
/// MFMAs, since the compiler represents the ordinary instruction as a "scaled"
/// MFMA with a scale of 0.
static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m,
                        uint32_t n, uint32_t k, uint32_t b, Chipset chipset) {
  aType = getElementTypeOrSelf(aType);
  bType = getElementTypeOrSelf(bType);
  destType = getElementTypeOrSelf(destType);

  if (chipset < kGfx950)
    return std::nullopt;
  if (!isa<Float32Type>(destType))
    return std::nullopt;

  std::optional<uint32_t> aTypeCode = mfmaTypeSelectCode(aType);
  std::optional<uint32_t> bTypeCode = mfmaTypeSelectCode(bType);
  if (!aTypeCode || !bTypeCode)
    return std::nullopt;

  if (m == 32 && n == 32 && k == 64 && b == 1)
    return std::tuple{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(),
                      *aTypeCode, *bTypeCode};
  if (m == 16 && n == 16 && k == 128 && b == 1)
    return std::tuple{
        ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), *aTypeCode,
        *bTypeCode};

  return std::nullopt;
}

static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
mfmaOpToScaledIntrinsic(MFMAOp mfma, Chipset chipset) {
  return mfmaOpToScaledIntrinsic(
      mfma.getSourceA().getType(), mfma.getSourceB().getType(),
      mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(),
      mfma.getBlocks(), chipset);
}

static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
  return mfmaOpToScaledIntrinsic(smfma.getSourceA().getType(),
                                 smfma.getSourceB().getType(),
                                 smfma.getDestC().getType(), smfma.getM(),
                                 smfma.getN(), smfma.getK(), 1u, chipset);
}

/// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
/// for RDNA3/4 architectures.
static std::optional<StringRef>
wmmaOpToIntrinsicRDNA(Type elemSourceType, Type elemBSourceType,
                      Type elemDestType, uint32_t k, bool isRDNA3) {
  using fp8 = Float8E4M3FNType;
  using bf8 = Float8E5M2Type;

  // Handle k == 16 for RDNA3/4.
  if (k == 16) {
    // Common patterns for RDNA3 and RDNA4.
    if (elemSourceType.isF16() && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
    if (elemSourceType.isBF16() && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
    if (elemSourceType.isF16() && elemDestType.isF16())
      return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
    if (elemSourceType.isBF16() && elemDestType.isBF16())
      return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
    if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
      return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();

    // RDNA3 specific patterns.
    if (isRDNA3) {
      if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
        return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
      return std::nullopt;
    }

    // RDNA4 specific patterns (fp8/bf8).
    if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
        elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
    if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
        elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
    if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
        elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
    if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
        elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
      return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();

    return std::nullopt;
  }

  // Handle k == 32 for RDNA4.
  if (k == 32 && !isRDNA3) {
    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
      return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
  }

  return std::nullopt;
}

/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
/// for the gfx1250 architecture.
static std::optional<StringRef> wmmaOpToIntrinsicGfx1250(Type elemSourceType,
                                                         Type elemBSourceType,
                                                         Type elemDestType,
                                                         uint32_t k) {
  using fp8 = Float8E4M3FNType;
  using bf8 = Float8E5M2Type;

  if (k == 4) {
    if (elemSourceType.isF32() && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x4_f32::getOperationName();

    return std::nullopt;
  }

  if (k == 32) {
    if (elemSourceType.isF16() && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x32_f16::getOperationName();
    if (elemSourceType.isBF16() && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x32_bf16::getOperationName();
    if (elemSourceType.isF16() && elemDestType.isF16())
      return ROCDL::wmma_f16_16x16x32_f16::getOperationName();
    if (elemSourceType.isBF16() && elemDestType.isBF16())
      return ROCDL::wmma_bf16_16x16x32_bf16::getOperationName();

    return std::nullopt;
  }

  if (k == 64) {
    if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x64_fp8_fp8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x64_fp8_fp8::getOperationName();
    }
    if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x64_fp8_bf8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x64_fp8_bf8::getOperationName();
    }
    if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x64_bf8_bf8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x64_bf8_bf8::getOperationName();
    }
    if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x64_bf8_fp8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x64_bf8_fp8::getOperationName();
    }
    if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
      return ROCDL::wmma_i32_16x16x64_iu8::getOperationName();

    return std::nullopt;
  }

  if (k == 128) {
    if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x128_fp8_fp8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x128_fp8_fp8::getOperationName();
    }
    if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x128_fp8_bf8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x128_fp8_bf8::getOperationName();
    }
    if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x128_bf8_bf8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x128_bf8_bf8::getOperationName();
    }
    if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x128_bf8_fp8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x128_bf8_fp8::getOperationName();
    }

    return std::nullopt;
  }

  return std::nullopt;
}

/// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
                                                  Chipset chipset) {
  auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
  auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
  auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
  Type elemSourceType = sourceVectorType.getElementType();
  Type elemBSourceType = sourceBVectorType.getElementType();
  Type elemDestType = destVectorType.getElementType();

  const uint32_t k = wmma.getK();
  const bool isRDNA3 = chipset.majorVersion == 11;
  const bool isRDNA4 = chipset.majorVersion == 12 && chipset.minorVersion == 0;

  // Handle RDNA3 and RDNA4.
  if (isRDNA3 || isRDNA4)
    return wmmaOpToIntrinsicRDNA(elemSourceType, elemBSourceType, elemDestType,
                                 k, isRDNA3);

  // Handle gfx1250.
  if (chipset == kGfx1250)
    return wmmaOpToIntrinsicGfx1250(elemSourceType, elemBSourceType,
                                    elemDestType, k);

  return std::nullopt;
}

namespace {
struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
  MFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type outType = typeConverter->convertType(op.getDestD().getType());
    Type intrinsicOutType = outType;
    if (auto outVecType = dyn_cast<VectorType>(outType))
      if (outVecType.getElementType().isBF16())
        intrinsicOutType = outVecType.clone(rewriter.getI16Type());

    if (chipset.majorVersion != 9 || chipset < kGfx908)
      return op->emitOpError("MFMA only supported on gfx908+");
    uint32_t getBlgpField = static_cast<uint32_t>(op.getBlgp());
    if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
      if (chipset < kGfx942)
        return op.emitOpError(
            "negation unsupported on chipsets older than gfx942");
      getBlgpField |=
          op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
    }
    std::optional<StringRef> maybeIntrinsic = mfmaOpToIntrinsic(op, chipset);
    std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
        maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
    if (!maybeIntrinsic.has_value() && !maybeScaledIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching MFMA size on given chipset");

    bool isScaled =
        !maybeIntrinsic.has_value() && maybeScaledIntrinsic.has_value();
    if (isScaled &&
        (adaptor.getAbid() > 0 || getBlgpField > 0 || op.getCbsz() > 0)) {
      return op.emitOpError(
          "non-default abid, blgp, and cbsz aren't supported on MFMAs that can "
          "be scaled as those fields are used for type information");
    }

    StringRef intrinsicName =
        isScaled ? std::get<0>(*maybeScaledIntrinsic) : *maybeIntrinsic;
    // Determine if we can use bf16 in the intrinsic. Newer MFMAs on gfx950+
    // allow bf16 as the input; see IntrinsicsAMDGPU.td for reference.
    bool allowBf16 = [&]() {
      if (chipset < kGfx950)
        return false;
      if (isScaled)
        return true;
      return intrinsicName.contains("16x16x32.bf16") ||
             intrinsicName.contains("32x32x16.bf16");
    }();
    OperationState loweredOp(loc, intrinsicName);
    loweredOp.addTypes(intrinsicOutType);
    loweredOp.addOperands({convertMFMAVectorOperand(
                               rewriter, loc, adaptor.getSourceA(), allowBf16),
                           convertMFMAVectorOperand(
                               rewriter, loc, adaptor.getSourceB(), allowBf16),
                           adaptor.getDestC()});
    if (isScaled) {
      Value zero = createI32Constant(rewriter, loc, 0);
      auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
      loweredOp.addOperands({createI32Constant(rewriter, loc, aTypeCode),
                             createI32Constant(rewriter, loc, bTypeCode),
                             /*scale A byte=*/zero, /*scale A=*/zero,
                             /*scale B byte=*/zero, /*scale B=*/zero});
    } else {
      loweredOp.addOperands({createI32Constant(rewriter, loc, op.getCbsz()),
                             createI32Constant(rewriter, loc, op.getAbid()),
                             createI32Constant(rewriter, loc, getBlgpField)});
    }
    Value lowered = rewriter.create(loweredOp)->getResult(0);
    if (outType != intrinsicOutType)
      lowered = LLVM::BitcastOp::create(rewriter, loc, outType, lowered);
    rewriter.replaceOp(op, lowered);
    return success();
  }
};
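// As a rough sketch (attribute syntax elided): a non-scaled
//   %d = amdgpu.mfma %a * %b + %c   (m = 32, n = 32, k = 8, blocks = 1)
// on vector<4xf16> sources and a vector<16xf32> accumulator becomes
// rocdl.mfma.f32.32x32x8f16 applied to the converted operands followed by
// three i32 immediates for cbsz, abid, and blgp.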

struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
  ScaledMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type intrinsicOutType = typeConverter->convertType(op.getDestD().getType());

    if (chipset.majorVersion != 9 || chipset < kGfx950)
      return op->emitOpError("scaled MFMA only supported on gfx950+");
    std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
        maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
    if (!maybeScaledIntrinsic.has_value())
      return op.emitOpError(
          "no intrinsic matching scaled MFMA size on given chipset");

    auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
    OperationState loweredOp(loc, intrinsicName);
    loweredOp.addTypes(intrinsicOutType);
    loweredOp.addOperands(
        {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
         convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
         adaptor.getDestC()});
    Value scalesIdxA =
        createI32Constant(rewriter, loc, adaptor.getScalesIdxA());
    Value scalesIdxB =
        createI32Constant(rewriter, loc, adaptor.getScalesIdxB());
    loweredOp.addOperands(
        {createI32Constant(rewriter, loc, aTypeCode),
         createI32Constant(rewriter, loc, bTypeCode),
         /*scales idx A=*/scalesIdxA,
         /*scales A=*/
         castMFMAScaleOperand(rewriter, loc, adaptor.getScalesA()),
         /*scales idx B=*/scalesIdxB,
         /*scales B=*/
         castMFMAScaleOperand(rewriter, loc, adaptor.getScalesB())});
    Value lowered = rewriter.create(loweredOp)->getResult(0);
    rewriter.replaceOp(op, lowered);
    return success();
  }
};

struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
  WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto outType =
        typeConverter->convertType<VectorType>(op.getDestD().getType());
    if (!outType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
      return op->emitOpError("WMMA only supported on gfx11 and gfx12");

    bool isGFX1250 = chipset >= kGfx1250;

    // The WMMA operations represent vectors of bf16s as vectors of i16s
    // (except on gfx1250), so we need to bitcast bfloats to i16 and then
    // bitcast them back.
    auto aType = cast<VectorType>(adaptor.getSourceA().getType());
    auto bType = cast<VectorType>(adaptor.getSourceB().getType());
    auto destCType = cast<VectorType>(adaptor.getDestC().getType());
    bool castAToI16 = aType.getElementType().isBF16() && !isGFX1250;
    bool castBToI16 = bType.getElementType().isBF16() && !isGFX1250;
    bool castDestCToI16 = destCType.getElementType().isBF16() && !isGFX1250;
    bool castOutToI16 = outType.getElementType().isBF16() && !isGFX1250;
    VectorType rawOutType = outType;
    if (castOutToI16)
      rawOutType = outType.clone(rewriter.getI16Type());
    Value a = adaptor.getSourceA();
    if (castAToI16)
      a = LLVM::BitcastOp::create(rewriter, loc,
                                  aType.clone(rewriter.getI16Type()), a);
    Value b = adaptor.getSourceB();
    if (castBToI16)
      b = LLVM::BitcastOp::create(rewriter, loc,
                                  bType.clone(rewriter.getI16Type()), b);
    Value destC = adaptor.getDestC();
    if (castDestCToI16)
      destC = LLVM::BitcastOp::create(
          rewriter, loc, destCType.clone(rewriter.getI16Type()), destC);

    std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);

    if (!maybeIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching WMMA on the given chipset");

    if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0)
      return op.emitOpError("subwordOffset not supported on gfx12+");

    SmallVector<Value, 4> operands;
    SmallVector<NamedAttribute, 4> attrs;
    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(), a,
                         op.getSourceA(), operands, attrs, "signA");
    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(), b,
                         op.getSourceB(), operands, attrs, "signB");
    wmmaPushOutputOperand(rewriter, loc, typeConverter, destC,
                          op.getSubwordOffset(), op.getClamp(), operands,
                          attrs);

    OperationState loweredOp(loc, *maybeIntrinsic);
    loweredOp.addTypes(rawOutType);
    loweredOp.addOperands(operands);
    loweredOp.addAttributes(attrs);
    Operation *lowered = rewriter.create(loweredOp);

    Operation *maybeCastBack = lowered;
    if (rawOutType != outType)
      maybeCastBack = LLVM::BitcastOp::create(rewriter, loc, outType,
                                              lowered->getResult(0));
    rewriter.replaceOp(op, maybeCastBack->getResults());

    return success();
  }
};
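// As a rough sketch: on gfx11, an amdgpu.wmma of vector<16xf16> sources into
// a vector<8xf32> accumulator maps to rocdl.wmma.f32.16x16x16.f16, with any
// bf16 operands first bitcast to i16 vectors as described above.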
1366
1367struct TransposeLoadOpLowering
1368 : public ConvertOpToLLVMPattern<TransposeLoadOp> {
1369 TransposeLoadOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1370 : ConvertOpToLLVMPattern<TransposeLoadOp>(converter), chipset(chipset) {}
1371
1372 Chipset chipset;
1373
1374 LogicalResult
1375 matchAndRewrite(TransposeLoadOp op, TransposeLoadOpAdaptor adaptor,
1376 ConversionPatternRewriter &rewriter) const override {
1377 if (chipset != kGfx950)
1378 return op.emitOpError("Non-gfx950 chipset not supported");
1379
1380 Location loc = op.getLoc();
1381 auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
1382
1383 // Elements in subbyte memrefs are stored non-contiguously,
1384 // reject if source is sub-byte memref. Use emulated memrefs instead.
1385 size_t srcElementSize =
1386 srcMemRefType.getElementType().getIntOrFloatBitWidth();
1387 if (srcElementSize < 8)
1388 return op.emitOpError("Expect source memref to have at least 8 bits "
1389 "element size, got ")
1390 << srcElementSize;
1391
1392 auto resultType = cast<VectorType>(op.getResult().getType());
1393 Value srcPtr =
1394 getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
1395 (adaptor.getSrcIndices()));
1396
1397 size_t numElements = resultType.getNumElements();
1398 size_t elementTypeSize =
1399 resultType.getElementType().getIntOrFloatBitWidth();
1400
1401 // ROCDL transpose load intrinsics return vectors of 32-bit integers when
1402 // the element size is smaller than 16 bits.
1403 Type rocdlResultType = VectorType::get((numElements * elementTypeSize) / 32,
1404 rewriter.getIntegerType(32));
1405 Type llvmResultType = typeConverter->convertType(resultType);
1406
1407 switch (elementTypeSize) {
1408 case 4: {
1409 assert(numElements == 16);
1410 auto rocdlOp = ROCDL::ds_read_tr4_b64::create(rewriter, loc,
1411 rocdlResultType, srcPtr);
1412 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
1413 break;
1414 }
1415 case 6: {
1416 assert(numElements == 16);
1417 auto rocdlOp = ROCDL::ds_read_tr6_b96::create(rewriter, loc,
1418 rocdlResultType, srcPtr);
1419 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
1420 break;
1421 }
1422 case 8: {
1423 assert(numElements == 8);
1424 auto rocdlOp = ROCDL::ds_read_tr8_b64::create(rewriter, loc,
1425 rocdlResultType, srcPtr);
1426 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
1427 break;
1428 }
1429 case 16: {
1430 assert(numElements == 4);
1431 rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr16_b64>(op, llvmResultType,
1432 srcPtr);
1433 break;
1434 }
1435 default:
1436 return op.emitOpError("Unsupported element size for transpose load");
1437 }
1438 return success();
1439 }
1440};
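// Editor's note: a minimal compile-time sketch (not part of the upstream
// lowering) checking the result-width arithmetic above. For sub-16-bit
// elements the intrinsic returns (numElements * elementTypeSize) / 32 packed
// i32 values, which matches the b64/b96 suffix of each ROCDL op.
namespace transpose_load_width_example {
constexpr unsigned packedI32Count(unsigned numElements, unsigned elementBits) {
  return (numElements * elementBits) / 32;
}
static_assert(packedI32Count(16, 4) == 2, "ds_read_tr4_b64 returns 64 bits");
static_assert(packedI32Count(16, 6) == 3, "ds_read_tr6_b96 returns 96 bits");
static_assert(packedI32Count(8, 8) == 2, "ds_read_tr8_b64 returns 64 bits");
// 16-bit elements (4 x i16 = 64 bits) are returned directly, with no bitcast.
} // namespace transpose_load_width_example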
1441
1442struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
1443 GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1444 : ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}
1445
1446 Chipset chipset;
1447
1448 LogicalResult
1449 matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor,
1450 ConversionPatternRewriter &rewriter) const override {
1451 if (chipset.majorVersion < 9 || chipset.majorVersion > 10)
1452 return op.emitOpError("pre-gfx9 and post-gfx10 not supported");
1453
1454 Location loc = op.getLoc();
1455
1456 auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
1457 auto dstMemRefType = cast<MemRefType>(op.getDst().getType());
1458
1459 // TODO: instead of only transferring one element per thread, we could
1460 // augment it to transfer multiple elements per thread by issuing multiple
1461 // `global_load_lds` instructions.
1462 Type transferType = op.getTransferType();
1463 int loadWidth = [&]() -> int {
1464 if (auto transferVectorType = dyn_cast<VectorType>(transferType)) {
1465 return (transferVectorType.getNumElements() *
1466 transferVectorType.getElementTypeBitWidth()) /
1467 8;
1468 }
1469 return transferType.getIntOrFloatBitWidth() / 8;
1470 }();
1471
1472 // Currently only 1, 2, 4, 12 and 16 byte loads are supported.
1473 if (!llvm::is_contained({1, 2, 4, 12, 16}, loadWidth))
1474 return op.emitOpError("unsupported element size for gather to LDS");
1475
1476 if (chipset != kGfx950 && llvm::is_contained({12, 16}, loadWidth))
1477 return op.emitOpError("Gather to LDS instructions with 12-byte and "
1478 "16-byte load widths are only supported on gfx950");
1479
1480 Value srcPtr =
1481 getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
1482 (adaptor.getSrcIndices()));
1483 Value dstPtr =
1484 getStridedElementPtr(rewriter, loc, dstMemRefType, adaptor.getDst(),
1485 (adaptor.getDstIndices()));
1486
1487 rewriter.replaceOpWithNewOp<ROCDL::LoadToLDSOp>(
1488 op, srcPtr, dstPtr, rewriter.getI32IntegerAttr(loadWidth),
1489 /*offset=*/rewriter.getI32IntegerAttr(0),
1490 /*aux=*/rewriter.getI32IntegerAttr(0), ArrayAttr{}, ArrayAttr{},
1491 ArrayAttr{});
1492
1493 return success();
1494 }
1495};
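// Editor's note: a minimal compile-time sketch (not part of the upstream
// lowering) of the transfer-width computation above: the width in bytes is
// numElements * elementBits / 8, and only 1, 2, 4, 12, and 16 are accepted.
namespace gather_to_lds_width_example {
constexpr int loadWidthBytes(int numElements, int elementBits) {
  return (numElements * elementBits) / 8;
}
static_assert(loadWidthBytes(1, 32) == 4, "f32: supported on all targets");
static_assert(loadWidthBytes(4, 16) == 8, "vector<4xf16>: 8 bytes, rejected");
static_assert(loadWidthBytes(4, 32) == 16, "vector<4xf32>: gfx950 only");
} // namespace gather_to_lds_width_example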
1496
1497namespace {
1498struct ExtPackedFp8OpLowering final
1499 : public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
1500 ExtPackedFp8OpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1501 : ConvertOpToLLVMPattern<amdgpu::ExtPackedFp8Op>(converter),
1502 chipset(chipset) {}
1503 Chipset chipset;
1504
1505 LogicalResult
1506 matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
1507 ConversionPatternRewriter &rewriter) const override;
1508};
1509
1510struct ScaledExtPackedMatrixOpLowering final
1511 : public ConvertOpToLLVMPattern<ScaledExtPackedMatrixOp> {
1512 ScaledExtPackedMatrixOpLowering(const LLVMTypeConverter &converter,
1513 Chipset chipset)
1514 : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedMatrixOp>(converter),
1515 chipset(chipset) {}
1516 Chipset chipset;
1517
1518 LogicalResult
1519 matchAndRewrite(ScaledExtPackedMatrixOp op,
1520 ScaledExtPackedMatrixOpAdaptor adaptor,
1521 ConversionPatternRewriter &rewriter) const override;
1522};
1523
1524struct PackedTrunc2xFp8OpLowering final
1525 : public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> {
1526 PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter,
1527 Chipset chipset)
1528 : ConvertOpToLLVMPattern<amdgpu::PackedTrunc2xFp8Op>(converter),
1529 chipset(chipset) {}
1530 Chipset chipset;
1531
1532 LogicalResult
1533 matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
1534 ConversionPatternRewriter &rewriter) const override;
1535};
1536
1537struct PackedStochRoundFp8OpLowering final
1538 : public ConvertOpToLLVMPattern<PackedStochRoundFp8Op> {
1539 PackedStochRoundFp8OpLowering(const LLVMTypeConverter &converter,
1540 Chipset chipset)
1541 : ConvertOpToLLVMPattern<amdgpu::PackedStochRoundFp8Op>(converter),
1542 chipset(chipset) {}
1543 Chipset chipset;
1544
1545 LogicalResult
1546 matchAndRewrite(PackedStochRoundFp8Op op,
1547 PackedStochRoundFp8OpAdaptor adaptor,
1548 ConversionPatternRewriter &rewriter) const override;
1549};
1550
1551struct ScaledExtPackedOpLowering final
1552 : public ConvertOpToLLVMPattern<ScaledExtPackedOp> {
1553 ScaledExtPackedOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1554 : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedOp>(converter),
1555 chipset(chipset) {}
1556 Chipset chipset;
1557
1558 LogicalResult
1559 matchAndRewrite(ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
1560 ConversionPatternRewriter &rewriter) const override;
1561};
1562
1563struct PackedScaledTruncOpLowering final
1564 : public ConvertOpToLLVMPattern<PackedScaledTruncOp> {
1565 PackedScaledTruncOpLowering(const LLVMTypeConverter &converter,
1566 Chipset chipset)
1567 : ConvertOpToLLVMPattern<amdgpu::PackedScaledTruncOp>(converter),
1568 chipset(chipset) {}
1569 Chipset chipset;
1570
1571 LogicalResult
1572 matchAndRewrite(PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
1573 ConversionPatternRewriter &rewriter) const override;
1574};
1575
1576} // end namespace
1577
1578LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
1579 ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
1580 ConversionPatternRewriter &rewriter) const {
1581 Location loc = op.getLoc();
1582 if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
1583 return rewriter.notifyMatchFailure(
1584 loc, "Fp8 conversion instructions are not available on target "
1585 "architecture and their emulation is not implemented");
1586 Type v4i8 =
1587 getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type()));
1588 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
1589 Type f32 = getTypeConverter()->convertType(op.getResult().getType());
1590
1591 Value source = adaptor.getSource();
1592 auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
1593 auto resultVecType = dyn_cast<VectorType>(op.getResult().getType());
1594 Type sourceElemType = getElementTypeOrSelf(op.getSource());
1595 // Extend to a v4i8
1596 if (!sourceVecType || sourceVecType.getNumElements() < 4) {
1597 Value longVec = LLVM::UndefOp::create(rewriter, loc, v4i8);
1598 if (!sourceVecType) {
1599 longVec = LLVM::InsertElementOp::create(
1600 rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0));
1601 } else {
1602 for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
1603 Value idx = createI32Constant(rewriter, loc, i);
1604 Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
1605 longVec =
1606 LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
1607 }
1608 }
1609 source = longVec;
1610 }
1611 Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
1612 if (resultVecType) {
1613 if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
1614 rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source,
1615 op.getIndex());
1616 } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
1617 rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Fp8Op>(op, f32, i32Source,
1618 op.getIndex());
1619 }
1620 } else {
1621 if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
1622 rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
1623 op.getIndex());
1624 } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
1625 rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
1626 op.getIndex());
1627 }
1628 }
1629 return success();
1630}
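// Editor's note: a minimal compile-time sketch (not part of the upstream
// lowering) of why the widening above exists: any 1-4 element fp8 vector can
// be reinterpreted as a single i32, which is the operand shape the ROCDL
// conversion intrinsics consume. Padding lanes are undef in the real code;
// they are shown as zero here purely for illustration.
namespace fp8_pack_example {
constexpr unsigned packBytes(unsigned b0, unsigned b1, unsigned b2,
                             unsigned b3) {
  return b0 | (b1 << 8) | (b2 << 16) | (b3 << 24); // little-endian lanes
}
static_assert(packBytes(0x3C, 0, 0, 0) == 0x3C, "single-element source");
static_assert(packBytes(1, 2, 3, 4) == 0x04030201, "full 4-element source");
} // namespace fp8_pack_example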
1631
1632int32_t getScaleSel(int32_t blockSize, unsigned bitWidth, int32_t scaleWaveHalf,
1633 int32_t firstScaleByte) {
1634 // When lowering amdgpu.scaled_ext_packed_matrix to rocdl.cvt.scale.pk*.f*.f*
1635 // operations, the attributes blockSize, sourceType, scaleWaveHalf, and
1636 // firstScaleByte are merged into a single scaleSel attribute. This
1637 // function implements that merging. (Note: scaleWaveHalf isn't a
1638 // high-level attribute but is derived from firstScaleLane.)
1639 assert(llvm::is_contained({16, 32}, blockSize));
1640 assert(llvm::is_contained(llvm::ArrayRef<unsigned>{4, 6, 8}, bitWidth));
1641
1642 const bool isFp8 = bitWidth == 8;
1643 const bool isBlock16 = blockSize == 16;
1644
1645 if (!isFp8) {
1646 int32_t bit0 = isBlock16;
1647 assert(llvm::is_contained({0, 1, 2}, firstScaleByte));
1648 int32_t bit1 = (firstScaleByte == 2) << 1;
1649 assert(llvm::is_contained({0, 1}, scaleWaveHalf));
1650 int32_t bit2 = scaleWaveHalf << 2;
1651 return bit2 | bit1 | bit0;
1652 }
1653
1654 int32_t bit0 = isBlock16;
1655 // firstScaleByte is guaranteed to fit in two bits.
1656 assert(llvm::is_contained({0, 1, 2, 3}, firstScaleByte));
1657 int32_t bits2and1 = firstScaleByte << 1;
1658 assert(llvm::is_contained({0, 1}, scaleWaveHalf));
1659 int32_t bit3 = scaleWaveHalf << 3;
1660 int32_t bits = bit3 | bits2and1 | bit0;
1661 // These are invalid cases.
1662 assert(!llvm::is_contained(
1663 {0b0011, 0b0101, 0b0111, 0b1000, 0b1001, 0b1011, 0b1111}, bits));
1664 return bits;
1665}
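// Editor's note: a minimal compile-time sketch (hypothetical helpers, not
// part of the upstream lowering) mirroring the scaleSel packing above:
// fp4/fp6 sources use a 3-bit encoding, fp8 sources a 4-bit encoding.
namespace scale_sel_example {
constexpr int32_t scaleSelSmallFloat(bool isBlock16, int32_t firstScaleByte,
                                     int32_t scaleWaveHalf) {
  return (scaleWaveHalf << 2) | ((firstScaleByte == 2) << 1) |
         (isBlock16 ? 1 : 0);
}
constexpr int32_t scaleSelFp8(bool isBlock16, int32_t firstScaleByte,
                              int32_t scaleWaveHalf) {
  return (scaleWaveHalf << 3) | (firstScaleByte << 1) | (isBlock16 ? 1 : 0);
}
// fp4, 16-wide blocks, scale byte 2, upper wave half -> 0b111.
static_assert(scaleSelSmallFloat(true, 2, 1) == 0b111, "");
// fp8, 32-wide blocks, scale byte 1, lower wave half -> 0b0010.
static_assert(scaleSelFp8(false, 1, 0) == 0b0010, "");
} // namespace scale_sel_example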
1666
1667static std::optional<StringRef>
1668scaledExtPacked816ToIntrinsic(Type srcElemType, Type destElemType) {
1669 using fp4 = Float4E2M1FNType;
1670 using fp8 = Float8E4M3FNType;
1671 using bf8 = Float8E5M2Type;
1672 using fp6 = Float6E2M3FNType;
1673 using bf6 = Float6E3M2FNType;
1674 if (isa<fp4>(srcElemType)) {
1675 if (destElemType.isF16())
1676 return ROCDL::CvtPkScalePk8F16Fp4Op::getOperationName();
1677 if (destElemType.isBF16())
1678 return ROCDL::CvtPkScalePk8Bf16Fp4Op::getOperationName();
1679 if (destElemType.isF32())
1680 return ROCDL::CvtPkScalePk8F32Fp4Op::getOperationName();
1681 return std::nullopt;
1682 }
1683 if (isa<fp8>(srcElemType)) {
1684 if (destElemType.isF16())
1685 return ROCDL::CvtPkScalePk8F16Fp8Op::getOperationName();
1686 if (destElemType.isBF16())
1687 return ROCDL::CvtPkScalePk8Bf16Fp8Op::getOperationName();
1688 if (destElemType.isF32())
1689 return ROCDL::CvtPkScalePk8F32Fp8Op::getOperationName();
1690 return std::nullopt;
1691 }
1692 if (isa<bf8>(srcElemType)) {
1693 if (destElemType.isF16())
1694 return ROCDL::CvtPkScalePk8F16Bf8Op::getOperationName();
1695 if (destElemType.isBF16())
1696 return ROCDL::CvtPkScalePk8Bf16Bf8Op::getOperationName();
1697 if (destElemType.isF32())
1698 return ROCDL::CvtPkScalePk8F32Bf8Op::getOperationName();
1699 return std::nullopt;
1700 }
1701 if (isa<fp6>(srcElemType)) {
1702 if (destElemType.isF16())
1703 return ROCDL::CvtPkScalePk16F16Fp6Op::getOperationName();
1704 if (destElemType.isBF16())
1705 return ROCDL::CvtPkScalePk16Bf16Fp6Op::getOperationName();
1706 if (destElemType.isF32())
1707 return ROCDL::CvtPkScalePk16F32Fp6Op::getOperationName();
1708 return std::nullopt;
1709 }
1710 if (isa<bf6>(srcElemType)) {
1711 if (destElemType.isF16())
1712 return ROCDL::CvtPkScalePk16F16Bf6Op::getOperationName();
1713 if (destElemType.isBF16())
1714 return ROCDL::CvtPkScalePk16Bf16Bf6Op::getOperationName();
1715 if (destElemType.isF32())
1716 return ROCDL::CvtPkScalePk16F32Bf6Op::getOperationName();
1717 return std::nullopt;
1718 }
1719 llvm_unreachable("invalid combination of element types for packed conversion "
1720 "instructions");
1721}
1722
1723LogicalResult ScaledExtPackedMatrixOpLowering::matchAndRewrite(
1724 ScaledExtPackedMatrixOp op, ScaledExtPackedMatrixOpAdaptor adaptor,
1725 ConversionPatternRewriter &rewriter) const {
1726 using fp4 = Float4E2M1FNType;
1727 using fp8 = Float8E4M3FNType;
1728 using bf8 = Float8E5M2Type;
1729 using fp6 = Float6E2M3FNType;
1730 using bf6 = Float6E3M2FNType;
1731 Location loc = op.getLoc();
1732 if (chipset != kGfx1250) {
1733 return rewriter.notifyMatchFailure(
1734 loc,
1735 "Scaled fp packed conversion instructions are not available on target "
1736 "architecture and their emulation is not implemented");
1737 }
1738 // Convert user-facing firstScaleLane (0 or 16) to the half of the wave that
1739 // is being selected.
1740 int32_t scaleWaveHalf = op.getFirstScaleLane() / 16;
1741 int32_t firstScaleByte = op.getFirstScaleByte();
1742 int32_t blockSize = op.getBlockSize();
1743 auto sourceType = cast<VectorType>(op.getSource().getType());
1744 auto srcElemType = cast<FloatType>(sourceType.getElementType());
1745 unsigned bitWidth = srcElemType.getWidth();
1746
1747 auto targetType = cast<VectorType>(op.getResult().getType());
1748 auto destElemType = cast<FloatType>(targetType.getElementType());
1749
1750 IntegerType i32 = rewriter.getI32Type();
1751 Value source = adaptor.getSource();
1752 Type llvmResultType = typeConverter->convertType(op.getResult().getType());
1753 Type packedType = nullptr;
1754 if (isa<fp4>(srcElemType)) {
1755 packedType = i32;
1756 packedType = getTypeConverter()->convertType(packedType);
1757 } else if (isa<fp8, bf8>(srcElemType)) {
1758 packedType = VectorType::get(2, i32);
1759 packedType = getTypeConverter()->convertType(packedType);
1760 } else if (isa<fp6, bf6>(srcElemType)) {
1761 packedType = VectorType::get(3, i32);
1762 packedType = getTypeConverter()->convertType(packedType);
1763 } else {
1764 llvm_unreachable("invalid element type for packed scaled ext");
1765 }
1766
1767 if (!packedType || !llvmResultType) {
1768 return rewriter.notifyMatchFailure(op, "type conversion failed");
1769 }
1770
1771 std::optional<StringRef> maybeIntrinsic =
1772 scaledExtPacked816ToIntrinsic(srcElemType, destElemType);
1773 if (!maybeIntrinsic.has_value())
1774 return op.emitOpError(
1775 "no intrinsic matching packed scaled conversion on the given chipset");
1776
1777 int32_t scaleSel =
1778 getScaleSel(blockSize, bitWidth, scaleWaveHalf, firstScaleByte);
1779 Value castedScale =
1780 LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getScale());
1781 Value castedSource =
1782 LLVM::BitcastOp::create(rewriter, loc, packedType, source);
1783
1784 OperationState loweredOp(loc, *maybeIntrinsic);
1785 loweredOp.addTypes({llvmResultType});
1786 loweredOp.addOperands({castedSource, castedScale});
1787
1788 SmallVector<NamedAttribute, 1> attrs;
1789 attrs.push_back(
1790 NamedAttribute("scaleSel", rewriter.getI32IntegerAttr(scaleSel)));
1791
1792 loweredOp.addAttributes(attrs);
1793 Operation *lowered = rewriter.create(loweredOp);
1794 rewriter.replaceOp(op, lowered);
1795
1796 return success();
1797}
1798
1799LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
1800 ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
1801 ConversionPatternRewriter &rewriter) const {
1802 Location loc = op.getLoc();
1803 if (chipset != kGfx950)
1804 return rewriter.notifyMatchFailure(
1805 loc, "Scaled fp conversion instructions are not available on target "
1806 "architecture and their emulation is not implemented");
1807 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
1808
1809 Value source = adaptor.getSource();
1810 Value scale = adaptor.getScale();
1811
1812 VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
1813 Type sourceElemType = sourceVecType.getElementType();
1814 VectorType destVecType = cast<VectorType>(op.getResult().getType());
1815 Type destElemType = destVecType.getElementType();
1816
1817 VectorType packedVecType;
1818 if (isa<Float8E5M2Type, Float8E4M3FNType>(sourceElemType)) {
1819 VectorType v4i8 = VectorType::get(4, rewriter.getI8Type());
1820 packedVecType = cast<VectorType>(getTypeConverter()->convertType(v4i8));
1821 } else if (isa<Float4E2M1FNType>(sourceElemType)) {
1822 VectorType v8i4 = VectorType::get(8, rewriter.getI4Type());
1823 packedVecType = cast<VectorType>(getTypeConverter()->convertType(v8i4));
1824 } else {
1825 llvm_unreachable("invalid element type for scaled ext");
1826 }
1827
1828 // Extend to a packedVectorType
1829 if (sourceVecType.getNumElements() < packedVecType.getNumElements()) {
1830 Value longVec = LLVM::ZeroOp::create(rewriter, loc, packedVecType);
1831 if (!sourceVecType) {
1832 longVec = LLVM::InsertElementOp::create(
1833 rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0));
1834 } else {
1835 for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
1836 Value idx = createI32Constant(rewriter, loc, i);
1837 Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
1838 longVec =
1839 LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
1840 }
1841 }
1842 source = longVec;
1843 }
1844 Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
1845
1846 if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF32())
1847 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Bf8Op>(
1848 op, destVecType, i32Source, scale, op.getIndex());
1849 else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF16())
1850 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Bf8Op>(
1851 op, destVecType, i32Source, scale, op.getIndex());
1852 else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isBF16())
1853 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Bf8Op>(
1854 op, destVecType, i32Source, scale, op.getIndex());
1855 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF32())
1856 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp8Op>(
1857 op, destVecType, i32Source, scale, op.getIndex());
1858 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF16())
1859 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp8Op>(
1860 op, destVecType, i32Source, scale, op.getIndex());
1861 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isBF16())
1862 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp8Op>(
1863 op, destVecType, i32Source, scale, op.getIndex());
1864 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF32())
1865 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp4Op>(
1866 op, destVecType, i32Source, scale, op.getIndex());
1867 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF16())
1868 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp4Op>(
1869 op, destVecType, i32Source, scale, op.getIndex());
1870 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isBF16())
1871 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp4Op>(
1872 op, destVecType, i32Source, scale, op.getIndex());
1873 else
1874 return failure();
1875
1876 return success();
1877}
1878
1879LogicalResult PackedScaledTruncOpLowering::matchAndRewrite(
1880 PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
1881 ConversionPatternRewriter &rewriter) const {
1882 Location loc = op.getLoc();
1883 if (chipset != kGfx950)
1884 return rewriter.notifyMatchFailure(
1885 loc, "Scaled fp conversion instructions are not available on target "
1886 "architecture and their emulation is not implemented");
1887 Type v2i16 = getTypeConverter()->convertType(
1888 VectorType::get(2, rewriter.getI16Type()));
1889 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
1890
1891 Type resultType = op.getResult().getType();
1892 Type resultElemType = getElementTypeOrSelf(resultType);
1893 VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
1894 Type sourceElemType = sourceVecType.getElementType();
1895
1896 Type intResultType = isa<Float4E2M1FNType>(resultElemType) ? i32 : v2i16;
1897
1898 Value source = adaptor.getSource();
1899 Value scale = adaptor.getScale();
1900 Value existing = adaptor.getExisting();
1901 if (existing)
1902 existing = LLVM::BitcastOp::create(rewriter, loc, intResultType, existing);
1903 else
1904 existing = LLVM::ZeroOp::create(rewriter, loc, intResultType);
1905
1906 if (sourceVecType.getNumElements() < 2) {
1907 Value c0 = createI32Constant(rewriter, loc, 0);
1908 Value elem0 = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
1909 VectorType v2 = VectorType::get(2, sourceElemType);
1910 source = LLVM::ZeroOp::create(rewriter, loc, v2);
1911 source = LLVM::InsertElementOp::create(rewriter, loc, source, elem0, c0);
1912 }
1913
1914 Value sourceA, sourceB;
1915 if (sourceElemType.isF32()) {
1916 Value c0 = createI32Constant(rewriter, loc, 0);
1917 Value c1 = createI32Constant(rewriter, loc, 1);
1918 sourceA = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
1919 sourceB = LLVM::ExtractElementOp::create(rewriter, loc, source, c1);
1920 }
1921
1922 Value result;
1923 if (sourceElemType.isF32() && isa<Float8E5M2Type>(resultElemType))
1924 result = ROCDL::CvtScaleF32PkBf8F32Op::create(rewriter, loc, intResultType,
1925 existing, sourceA, sourceB,
1926 scale, op.getIndex());
1927 else if (sourceElemType.isF16() && isa<Float8E5M2Type>(resultElemType))
1928 result = ROCDL::CvtScaleF32PkBf8F16Op::create(
1929 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
1930 else if (sourceElemType.isBF16() && isa<Float8E5M2Type>(resultElemType))
1931 result = ROCDL::CvtScaleF32PkBf8Bf16Op::create(
1932 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
1933 else if (sourceElemType.isF32() && isa<Float8E4M3FNType>(resultElemType))
1934 result = ROCDL::CvtScaleF32PkFp8F32Op::create(rewriter, loc, intResultType,
1935 existing, sourceA, sourceB,
1936 scale, op.getIndex());
1937 else if (sourceElemType.isF16() && isa<Float8E4M3FNType>(resultElemType))
1938 result = ROCDL::CvtScaleF32PkFp8F16Op::create(
1939 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
1940 else if (sourceElemType.isBF16() && isa<Float8E4M3FNType>(resultElemType))
1941 result = ROCDL::CvtScaleF32PkFp8Bf16Op::create(
1942 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
1943 else if (sourceElemType.isF32() && isa<Float4E2M1FNType>(resultElemType))
1944 result = ROCDL::CvtScaleF32PkFp4F32Op::create(rewriter, loc, intResultType,
1945 existing, sourceA, sourceB,
1946 scale, op.getIndex());
1947 else if (sourceElemType.isF16() && isa<Float4E2M1FNType>(resultElemType))
1948 result = ROCDL::CvtScaleF32PkFp4F16Op::create(
1949 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
1950 else if (sourceElemType.isBF16() && isa<Float4E2M1FNType>(resultElemType))
1951 result = ROCDL::CvtScaleF32PkFp4Bf16Op::create(
1952 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
1953 else
1954 return failure();
1955
1956 result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
1957 op, getTypeConverter()->convertType(resultType), result);
1958 return success();
1959}
1960
1961LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
1962 PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
1963 ConversionPatternRewriter &rewriter) const {
1964 Location loc = op.getLoc();
1965 if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
1966 return rewriter.notifyMatchFailure(
1967 loc, "Fp8 conversion instructions are not available on target "
1968 "architecture and their emulation is not implemented");
1969 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
1970
1971 Type resultType = op.getResult().getType();
1972 Type resultElemType = getElementTypeOrSelf(resultType);
1973
1974 Value sourceA = adaptor.getSourceA();
1975 Value sourceB = adaptor.getSourceB();
1976 if (!sourceB)
1977 sourceB = LLVM::UndefOp::create(rewriter, loc, sourceA.getType());
1978 Value existing = adaptor.getExisting();
1979 if (existing)
1980 existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
1981 else
1982 existing = LLVM::UndefOp::create(rewriter, loc, i32);
1983
1984 Value result;
1985 if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
1986 result = ROCDL::CvtPkBf8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
1987 existing, op.getWordIndex());
1988 else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
1989 result = ROCDL::CvtPkFp8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
1990 existing, op.getWordIndex());
1991
1992 result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
1993 op, getTypeConverter()->convertType(resultType), result);
1994 return success();
1995}
1996
1997LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
1998 PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
1999 ConversionPatternRewriter &rewriter) const {
2000 Location loc = op.getLoc();
2001 if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
2002 return rewriter.notifyMatchFailure(
2003 loc, "Fp8 conversion instructions are not available on target "
2004 "architecture and their emulation is not implemented");
2005 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2006
2007 Type resultType = op.getResult().getType();
2008 Type resultElemType = getElementTypeOrSelf(resultType);
2009
2010 Value source = adaptor.getSource();
2011 Value stoch = adaptor.getStochiasticParam();
2012 Value existing = adaptor.getExisting();
2013 if (existing)
2014 existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
2015 else
2016 existing = LLVM::UndefOp::create(rewriter, loc, i32);
2017
2018 Value result;
2019 if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
2020 result = ROCDL::CvtSrBf8F32Op::create(rewriter, loc, i32, source, stoch,
2021 existing, op.getStoreIndex());
2022 else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
2023 result = ROCDL::CvtSrFp8F32Op::create(rewriter, loc, i32, source, stoch,
2024 existing, op.getStoreIndex());
2025
2026 result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
2027 op, getTypeConverter()->convertType(resultType), result);
2028 return success();
2029}
2030
2031 // AMDGPUDPPLowering converts the amdgpu.dpp operation into the
2032 // corresponding ROCDL DPP (rocdl.update.dpp) instruction.
2033struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
2034 AMDGPUDPPLowering(const LLVMTypeConverter &converter, Chipset chipset)
2035 : ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
2036 Chipset chipset;
2037
2038 LogicalResult
2039 matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
2040 ConversionPatternRewriter &rewriter) const override {
2041
2042 // Convert the source operand to the corresponding LLVM type
2043 Location loc = DppOp.getLoc();
2044 Value src = adaptor.getSrc();
2045 Value old = adaptor.getOld();
2046 Type srcType = src.getType();
2047 Type oldType = old.getType();
2048 Type llvmType = nullptr;
2049 if (srcType.getIntOrFloatBitWidth() < 32) {
2050 llvmType = rewriter.getI32Type();
2051 } else if (isa<FloatType>(srcType)) {
2052 llvmType = (srcType.getIntOrFloatBitWidth() == 32)
2053 ? rewriter.getF32Type()
2054 : rewriter.getF64Type();
2055 } else if (isa<IntegerType>(srcType)) {
2056 llvmType = (srcType.getIntOrFloatBitWidth() == 32)
2057 ? rewriter.getI32Type()
2058 : rewriter.getI64Type();
2059 }
2060 auto llvmSrcIntType = typeConverter->convertType(
2061 rewriter.getIntegerType(srcType.getIntOrFloatBitWidth()));
2062
2063 // If the source type is narrower than 32 bits, bitcast it to i32.
2064 auto convertOperand = [&](Value operand, Type operandType) {
2065 if (operandType.getIntOrFloatBitWidth() <= 16) {
2066 if (llvm::isa<FloatType>(operandType)) {
2067 operand =
2068 LLVM::BitcastOp::create(rewriter, loc, llvmSrcIntType, operand);
2069 }
2070 auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
2071 32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
2072 Value undefVec = LLVM::UndefOp::create(rewriter, loc, llvmVecType);
2073 operand =
2074 LLVM::InsertElementOp::create(rewriter, loc, undefVec, operand,
2075 createI32Constant(rewriter, loc, 0));
2076 operand = LLVM::BitcastOp::create(rewriter, loc, llvmType, operand);
2077 }
2078 return operand;
2079 };
2080
2081 src = convertOperand(src, srcType);
2082 old = convertOperand(old, oldType);
2083
2084 // These values are taken from llvm/lib/Target/AMDGPU/SIDefines.h.
2085 enum DppCtrl : unsigned {
2086 ROW_SHL0 = 0x100,
2087 ROW_SHR0 = 0x110,
2088 ROW_ROR0 = 0x120,
2089 WAVE_SHL1 = 0x130,
2090 WAVE_ROL1 = 0x134,
2091 WAVE_SHR1 = 0x138,
2092 WAVE_ROR1 = 0x13C,
2093 ROW_MIRROR = 0x140,
2094 ROW_HALF_MIRROR = 0x141,
2095 BCAST15 = 0x142,
2096 BCAST31 = 0x143,
2097 };
2098
2099 auto kind = DppOp.getKind();
2100 auto permArgument = DppOp.getPermArgument();
2101 uint32_t DppCtrl = 0;
2102
2103 switch (kind) {
2104
2105 case DPPPerm::quad_perm:
2106 if (auto quadPermAttr = cast<ArrayAttr>(*permArgument)) {
2107 int32_t i = 0;
2108 for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
2109 uint32_t num = elem.getInt();
2110 DppCtrl |= num << (i * 2);
2111 i++;
2112 }
2113 }
2114 break;
2115 case DPPPerm::row_shl:
2116 if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
2117 DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
2118 }
2119 break;
2120 case DPPPerm::row_shr:
2121 if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
2122 DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
2123 }
2124 break;
2125 case DPPPerm::row_ror:
2126 if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
2127 DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
2128 }
2129 break;
2130 case DPPPerm::wave_shl:
2131 DppCtrl = DppCtrl::WAVE_SHL1;
2132 break;
2133 case DPPPerm::wave_shr:
2134 DppCtrl = DppCtrl::WAVE_SHR1;
2135 break;
2136 case DPPPerm::wave_rol:
2137 DppCtrl = DppCtrl::WAVE_ROL1;
2138 break;
2139 case DPPPerm::wave_ror:
2140 DppCtrl = DppCtrl::WAVE_ROR1;
2141 break;
2142 case DPPPerm::row_mirror:
2143 DppCtrl = DppCtrl::ROW_MIRROR;
2144 break;
2145 case DPPPerm::row_half_mirror:
2146 DppCtrl = DppCtrl::ROW_HALF_MIRROR;
2147 break;
2148 case DPPPerm::row_bcast_15:
2149 DppCtrl = DppCtrl::BCAST15;
2150 break;
2151 case DPPPerm::row_bcast_31:
2152 DppCtrl = DppCtrl::BCAST31;
2153 break;
2154 }
2155
2156 // Read the row_mask, bank_mask, and bound_ctrl attributes so they can be
2157 // passed to the ROCDL op as constants.
2158 auto rowMask = DppOp->getAttrOfType<IntegerAttr>("row_mask").getInt();
2159 auto bankMask = DppOp->getAttrOfType<IntegerAttr>("bank_mask").getInt();
2160 bool boundCtrl = DppOp->getAttrOfType<BoolAttr>("bound_ctrl").getValue();
2161
2162 // Create a ROCDL DPPUpdateOp (dpp.mov) with the appropriate attributes.
2163 auto dppMovOp =
2164 ROCDL::DPPUpdateOp::create(rewriter, loc, llvmType, old, src, DppCtrl,
2165 rowMask, bankMask, boundCtrl);
2166
2167 Value result = dppMovOp.getRes();
2168 if (srcType.getIntOrFloatBitWidth() < 32) {
2169 result = LLVM::TruncOp::create(rewriter, loc, llvmSrcIntType, result);
2170 if (!llvm::isa<IntegerType>(srcType)) {
2171 result = LLVM::BitcastOp::create(rewriter, loc, srcType, result);
2172 }
2173 }
2174
2175 // Replace the amdgpu.dpp operation with the result of the new
2176 // ROCDL DPPUpdateOp.
2177 rewriter.replaceOp(DppOp, ValueRange(result));
2178 return success();
2179 }
2180};
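// Editor's note: a minimal compile-time sketch (not part of the upstream
// lowering) of the quad_perm encoding above: four 2-bit lane selectors are
// packed into DppCtrl, with element i occupying bits [2*i, 2*i+1].
namespace quad_perm_example {
constexpr uint32_t quadPermCtrl(int s0, int s1, int s2, int s3) {
  return uint32_t(s0) | (uint32_t(s1) << 2) | (uint32_t(s2) << 4) |
         (uint32_t(s3) << 6);
}
static_assert(quadPermCtrl(0, 1, 2, 3) == 0xE4, "identity permutation");
static_assert(quadPermCtrl(3, 2, 1, 0) == 0x1B, "reverse within each quad");
} // namespace quad_perm_example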
2181
2182struct AMDGPUSwizzleBitModeLowering
2183 : public ConvertOpToLLVMPattern<SwizzleBitModeOp> {
2184 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
2185
2186 LogicalResult
2187 matchAndRewrite(SwizzleBitModeOp op, OpAdaptor adaptor,
2188 ConversionPatternRewriter &rewriter) const override {
2189 Location loc = op.getLoc();
2190 Type i32 = rewriter.getI32Type();
2191 Value src = adaptor.getSrc();
2192 SmallVector<Value> decomposed =
2193 LLVM::decomposeValue(rewriter, loc, src, i32);
2194 unsigned andMask = op.getAndMask();
2195 unsigned orMask = op.getOrMask();
2196 unsigned xorMask = op.getXorMask();
2197
2198 // bit 15 is 0 for the BitMode swizzle.
2199 // https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/
2200 unsigned mask = andMask | (orMask << 5) | (xorMask << 10);
2201 Value maskValue = createI32Constant(rewriter, loc, mask);
2202 SmallVector<Value> swizzled;
2203 for (Value v : decomposed) {
2204 Value res =
2205 ROCDL::DsSwizzleOp::create(rewriter, loc, v.getType(), v, maskValue);
2206 swizzled.emplace_back(res);
2207 }
2208
2209 Value result = LLVM::composeValue(rewriter, loc, swizzled, src.getType());
2210 rewriter.replaceOp(op, result);
2211 return success();
2212 }
2213};
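// Editor's note: a minimal compile-time sketch (not part of the upstream
// lowering) of the BitMode mask packing above: the and/or/xor masks occupy
// bits [0,4], [5,9], and [10,14] of the ds_swizzle offset, with bit 15 clear.
namespace swizzle_mask_example {
constexpr unsigned swizzleMask(unsigned andMask, unsigned orMask,
                               unsigned xorMask) {
  return andMask | (orMask << 5) | (xorMask << 10);
}
// Swap with the neighboring lane: keep all lane bits (and = 0x1F), xor bit 0.
static_assert(swizzleMask(0x1F, 0, 1) == 0x41F, "");
} // namespace swizzle_mask_example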
2214
2215struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneSwapOp> {
2217
2218 AMDGPUPermlaneLowering(const LLVMTypeConverter &converter, Chipset chipset)
2219 : ConvertOpToLLVMPattern<PermlaneSwapOp>(converter), chipset(chipset) {}
2220 Chipset chipset;
2221
2222 LogicalResult
2223 matchAndRewrite(PermlaneSwapOp op, OpAdaptor adaptor,
2224 ConversionPatternRewriter &rewriter) const override {
2225 if (chipset < kGfx950)
2226 return op->emitOpError("permlane_swap is only supported on gfx950+");
2227
2228 Location loc = op.getLoc();
2229 Type i32 = rewriter.getI32Type();
2230 Value src = adaptor.getSrc();
2231 unsigned rowLength = op.getRowLength();
2232 bool fi = op.getFetchInactive();
2233 bool boundctrl = op.getBoundCtrl();
2234
2235 SmallVector<Value> decomposed =
2236 LLVM::decomposeValue(rewriter, loc, src, i32);
2237
2238 SmallVector<Value> permuted;
2239 for (Value v : decomposed) {
2240 Value res;
2241 Type i32pair = LLVM::LLVMStructType::getLiteral(
2242 rewriter.getContext(), {v.getType(), v.getType()});
2243
2244 if (rowLength == 16)
2245 res = ROCDL::Permlane16SwapOp::create(rewriter, loc, i32pair, v, v, fi,
2246 boundctrl);
2247 else if (rowLength == 32)
2248 res = ROCDL::Permlane32SwapOp::create(rewriter, loc, i32pair, v, v, fi,
2249 boundctrl);
2250 else
2251 llvm_unreachable("unsupported row length");
2252
2253 Value vdst0 = LLVM::ExtractValueOp::create(rewriter, loc, res, {0});
2254 Value vdst1 = LLVM::ExtractValueOp::create(rewriter, loc, res, {1});
2255
2256 Value isEqual = LLVM::ICmpOp::create(rewriter, loc,
2257 LLVM::ICmpPredicate::eq, vdst0, v);
2258
2259 // Per `permlane(16|32)` semantics: if the first extracted element equals
2260 // 'v', the result is the second element; otherwise it is the first.
2261 Value vdstNew =
2262 LLVM::SelectOp::create(rewriter, loc, isEqual, vdst1, vdst0);
2263 permuted.emplace_back(vdstNew);
2264 }
2265
2266 Value result = LLVM::composeValue(rewriter, loc, permuted, src.getType());
2267 rewriter.replaceOp(op, result);
2268 return success();
2269 }
2270};
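// Editor's note: a minimal compile-time sketch (not part of the upstream
// lowering) of the decompose/compose strategy above: a value wider than 32
// bits (say an f64) is split into i32 pieces, each piece goes through its own
// permlane swap, and the pieces are recombined afterwards.
namespace permlane_decompose_example {
constexpr uint32_t lowHalf(uint64_t v) { return static_cast<uint32_t>(v); }
constexpr uint32_t highHalf(uint64_t v) {
  return static_cast<uint32_t>(v >> 32);
}
constexpr uint64_t compose(uint32_t lo, uint32_t hi) {
  return (static_cast<uint64_t>(hi) << 32) | lo;
}
static_assert(compose(lowHalf(0x1122334455667788ull),
                      highHalf(0x1122334455667788ull)) ==
                  0x1122334455667788ull,
              "decompose/compose round-trips");
} // namespace permlane_decompose_example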
2271
2272struct AMDGPUMakeDmaBaseLowering
2273 : public ConvertOpToLLVMPattern<MakeDmaBaseOp> {
2275
2276 AMDGPUMakeDmaBaseLowering(const LLVMTypeConverter &converter, Chipset chipset)
2277 : ConvertOpToLLVMPattern<MakeDmaBaseOp>(converter), chipset(chipset) {}
2278 Chipset chipset;
2279
2280 LogicalResult
2281 matchAndRewrite(MakeDmaBaseOp op, OpAdaptor adaptor,
2282 ConversionPatternRewriter &rewriter) const override {
2283 if (chipset < kGfx1250)
2284 return op->emitOpError("make_dma_base is only supported on gfx1250");
2285
2286 Location loc = op.getLoc();
2287
2288 ValueRange ldsIndices = adaptor.getLdsIndices();
2289 Value lds = adaptor.getLds();
2290 auto ldsMemRefType = cast<MemRefType>(op.getLds().getType());
2291
2292 Value ldsPtr =
2293 getStridedElementPtr(rewriter, loc, ldsMemRefType, lds, ldsIndices);
2294
2295 ValueRange globalIndices = adaptor.getGlobalIndices();
2296 Value global = adaptor.getGlobal();
2297 auto globalMemRefType = cast<MemRefType>(op.getGlobal().getType());
2298
2299 Value globalPtr = getStridedElementPtr(rewriter, loc, globalMemRefType,
2300 global, globalIndices);
2301
2302 Type i32 = rewriter.getI32Type();
2303 Type i64 = rewriter.getI64Type();
2304
2305 Value castForLdsAddr = LLVM::PtrToIntOp::create(rewriter, loc, i32, ldsPtr);
2306 Value castForGlobalAddr =
2307 LLVM::PtrToIntOp::create(rewriter, loc, i64, globalPtr);
2308
2309 Value lowHalf =
2310 LLVM::TruncOp::create(rewriter, loc, i32, castForGlobalAddr);
2311
2312 Value shift = LLVM::LShrOp::create(rewriter, loc, castForGlobalAddr,
2313 createI64Constant(rewriter, loc, 32));
2314
2315 Value highHalf = LLVM::TruncOp::create(rewriter, loc, i32, shift);
2316
2317 Value mask = createI32Constant(rewriter, loc, (1ull << 25) - 1);
2318 Value validHighHalf = LLVM::AndOp::create(rewriter, loc, highHalf, mask);
2319
2320 Value typeField = createI32Constant(rewriter, loc, 2 << 30);
2321 Value highHalfPlusType =
2322 LLVM::OrOp::create(rewriter, loc, validHighHalf, typeField);
2323
2324 Value c0 = createI32Constant(rewriter, loc, 0);
2325 Value c1 = createI32Constant(rewriter, loc, 1);
2326 Value c2 = createI32Constant(rewriter, loc, 2);
2327 Value c3 = createI32Constant(rewriter, loc, 3);
2328
2329 Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
2330 assert(v4i32 && "expected type conversion to succeed");
2331 Value result = LLVM::PoisonOp::create(rewriter, loc, v4i32);
2332 result = LLVM::InsertElementOp::create(rewriter, loc, result, c1, c0);
2333 result = LLVM::InsertElementOp::create(rewriter, loc, result,
2334 castForLdsAddr, c1);
2335 result = LLVM::InsertElementOp::create(rewriter, loc, result, lowHalf, c2);
2336 result = LLVM::InsertElementOp::create(rewriter, loc, result,
2337 highHalfPlusType, c3);
2338
2339 rewriter.replaceOp(op, result);
2340 return success();
2341 }
2342};
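// Editor's note: a minimal compile-time sketch (not part of the upstream
// lowering) of the v4i32 base descriptor built above: word 0 holds the
// constant 1, word 1 the 32-bit LDS address, word 2 the low 32 bits of the
// global address, and word 3 the next 25 address bits plus a type tag of 2
// in bits 30-31.
namespace dma_base_example {
constexpr uint32_t word2(uint64_t globalAddr) {
  return static_cast<uint32_t>(globalAddr);
}
constexpr uint32_t word3(uint64_t globalAddr) {
  uint32_t high = static_cast<uint32_t>(globalAddr >> 32);
  return (high & ((1u << 25) - 1)) | (2u << 30);
}
static_assert(word2(0x0123456789ABCDEFull) == 0x89ABCDEF, "");
static_assert(word3(0x0123456789ABCDEFull) == 0x81234567, "");
} // namespace dma_base_example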
2343
2344struct AMDGPUMakeDmaDescriptorLowering
2345 : public ConvertOpToLLVMPattern<MakeDmaDescriptorOp> {
2347
2348 AMDGPUMakeDmaDescriptorLowering(const LLVMTypeConverter &converter,
2349 Chipset chipset)
2350 : ConvertOpToLLVMPattern<MakeDmaDescriptorOp>(converter),
2351 chipset(chipset) {}
2352 Chipset chipset;
2353
2354 Value getDGroup0(OpAdaptor adaptor) const { return adaptor.getBase(); }
2355
2356 Value setValueAtOffset(ConversionPatternRewriter &rewriter, Location loc,
2357 Value accumulator, Value value, int64_t shift) const {
2358 shift = shift % 32;
2359 Value shiftAmount;
2360 if (shift != 0) {
2361 shiftAmount = createI32Constant(rewriter, loc, shift % 32);
2362 value = LLVM::ShlOp::create(rewriter, loc, value, shiftAmount);
2363 }
2364
2365 if (matchPattern(accumulator, mlir::m_Zero()))
2366 return value;
2367
2368 return LLVM::OrOp::create(rewriter, loc, accumulator, value);
2369 }
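  // Editor's note: a hypothetical illustration (not used by the lowering) of
  // how setValueAtOffset places a field: a field at descriptor bit offset 48
  // lands at bit 16 of its 32-bit word, since 48 % 32 == 16, so e.g.
  // exampleFieldPlacement(0, 0x1234, 48) == 0x12340000.
  static constexpr uint32_t exampleFieldPlacement(uint32_t word,
                                                  uint32_t value,
                                                  int64_t offset) {
    return word | (value << (offset % 32));
  }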
2370
2371 Value setDataSize(MakeDmaDescriptorOp op, OpAdaptor adaptor,
2372 ConversionPatternRewriter &rewriter, Location loc,
2373 Value sgpr0, ArrayRef<Value> consts) const {
2374 // Compute data_size.
2375 unsigned elementTypeWidthInBits = op.getElementTypeWidth();
2376 assert(
2377 llvm::is_contained<unsigned>({8, 16, 32, 64}, elementTypeWidthInBits) &&
2378 "expected type width to be 8, 16, 32, or 64.");
2379 int64_t dataSize = llvm::Log2_32(elementTypeWidthInBits / 8);
2380 return createI32Constant(rewriter, loc, dataSize << 16);
2381 }
2382
2383 Value setAtomicBarrier(MakeDmaDescriptorOp op, OpAdaptor adaptor,
2384 ConversionPatternRewriter &rewriter, Location loc,
2385 Value sgpr0, ArrayRef<Value> consts) const {
2386 bool atomicBarrierEnable = adaptor.getAtomicBarrierAddress() != nullptr;
2387 if (!atomicBarrierEnable)
2388 return sgpr0;
2389
2390 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 18);
2391 }
2392
2393 Value setIterateEnable(MakeDmaDescriptorOp op, OpAdaptor adaptor,
2394 ConversionPatternRewriter &rewriter, Location loc,
2395 Value sgpr0, ArrayRef<Value> consts) const {
2396 bool iterateEnable = adaptor.getGlobalIncrement() != nullptr;
2397 if (!iterateEnable)
2398 return sgpr0;
2399
2400 // TODO: In a future PR, add the other fields required for iteration.
2401 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 19);
2402 }
2403
2404 Value setPadEnable(MakeDmaDescriptorOp op, OpAdaptor adaptor,
2405 ConversionPatternRewriter &rewriter, Location loc,
2406 Value sgpr0, ArrayRef<Value> consts) const {
2407 bool padEnable = op.getPadAmount() != nullptr;
2408 if (!padEnable)
2409 return sgpr0;
2410
2411 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 20);
2412 }
2413
2414 Value setPadInterval(MakeDmaDescriptorOp op, OpAdaptor adaptor,
2415 ConversionPatternRewriter &rewriter, Location loc,
2416 Value sgpr0, ArrayRef<Value> consts) const {
2417 bool padEnable = op.getPadAmount() != nullptr;
2418 if (!padEnable)
2419 return sgpr0;
2420
2421 IntegerType i32 = rewriter.getI32Type();
2422 Value padInterval = adaptor.getPadInterval();
2423 // pre-condition: padInterval is a power of two between 2 and 256.
2424 padInterval = LLVM::CountTrailingZerosOp::create(rewriter, loc, i32,
2425 padInterval, false);
2426 padInterval = LLVM::SubOp::create(rewriter, loc, padInterval, consts[1]);
2427 // post-condition: padInterval is a value between 0 and 7.
2428 return setValueAtOffset(rewriter, loc, sgpr0, padInterval, 22);
2429 }
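  // Editor's note: a hypothetical illustration (not used by the lowering) of
  // the pad-interval encoding above: the field stores log2(interval) - 1,
  // computed with a count-trailing-zeros on a power of two, so interval
  // 2 -> 0, 8 -> 2, and 256 -> 7. `interval` must be a power of two >= 2.
  static constexpr int32_t examplePadIntervalEncoding(uint32_t interval) {
    int32_t log2 = 0;
    while ((interval & 1u) == 0) {
      interval >>= 1;
      ++log2;
    }
    return log2 - 1;
  }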
2430
2431 Value setPadAmount(MakeDmaDescriptorOp op, OpAdaptor adaptor,
2432 ConversionPatternRewriter &rewriter, Location loc,
2433 Value sgpr0, ArrayRef<Value> consts) const {
2434 bool padEnable = op.getPadAmount() != nullptr;
2435 if (!padEnable)
2436 return sgpr0;
2437
2438 Value padAmount = adaptor.getPadAmount();
2439 // pre-condition: padAmount is a value between 1 and 128.
2440 padAmount = LLVM::SubOp::create(rewriter, loc, padAmount, consts[1]);
2441 // post-condition: padAmount is a value between 0 and 127.
2442 return setValueAtOffset(rewriter, loc, sgpr0, padAmount, 25);
2443 }
2444
2445 Value setAtomicBarrierAddress(MakeDmaDescriptorOp op, OpAdaptor adaptor,
2446 ConversionPatternRewriter &rewriter,
2447 Location loc, Value sgpr1,
2448 ArrayRef<Value> consts) const {
2449 bool atomicBarrierEnable = adaptor.getAtomicBarrierAddress() != nullptr;
2450 if (!atomicBarrierEnable)
2451 return sgpr1;
2452
2453 Value atomicBarrierAddress = adaptor.getAtomicBarrierAddress();
2454 auto barrierAddressTy =
2455 cast<MemRefType>(op.getAtomicBarrierAddress().getType());
2456 ValueRange atomicBarrierIndices = adaptor.getAtomicBarrierIndices();
2457 atomicBarrierAddress =
2458 getStridedElementPtr(rewriter, loc, barrierAddressTy,
2459 atomicBarrierAddress, atomicBarrierIndices);
2460 IntegerType i32 = rewriter.getI32Type();
2461 // pre-condition: atomicBarrierAddress is aligned to 8 bytes, which
2462 // implies that the 3 LSBs are zero.
2463 atomicBarrierAddress =
2464 LLVM::PtrToIntOp::create(rewriter, loc, i32, atomicBarrierAddress);
2465 atomicBarrierAddress =
2466 LLVM::LShrOp::create(rewriter, loc, atomicBarrierAddress, consts[3]);
2467 Value mask = createI32Constant(rewriter, loc, 0xFFFF);
2468 atomicBarrierAddress =
2469 LLVM::AndOp::create(rewriter, loc, atomicBarrierAddress, mask);
2470 return setValueAtOffset(rewriter, loc, sgpr1, atomicBarrierAddress, 32);
2471 }
2472
2473 std::pair<Value, Value> setTensorDim0(MakeDmaDescriptorOp op,
2474 OpAdaptor adaptor,
2475 ConversionPatternRewriter &rewriter,
2476 Location loc, Value sgpr1, Value sgpr2,
2477 ArrayRef<Value> consts) const {
2478 SmallVector<OpFoldResult> mixedGlobalSizes = op.getMixedGlobalSizes();
2479 OpFoldResult tensorDim0OpFoldResult = mixedGlobalSizes.back();
2480 Value tensorDim0;
2481 if (auto attr = dyn_cast<Attribute>(tensorDim0OpFoldResult))
2482 tensorDim0 =
2483 createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
2484 else
2485 tensorDim0 = cast<Value>(tensorDim0OpFoldResult);
2486
2487 Value c16 = createI32Constant(rewriter, loc, 16);
2488 Value tensorDim0High = LLVM::LShrOp::create(rewriter, loc, tensorDim0, c16);
2489 sgpr1 = setValueAtOffset(rewriter, loc, sgpr1, tensorDim0, 48);
2490 sgpr2 = setValueAtOffset(rewriter, loc, sgpr2, tensorDim0High, 48 + 16);
2491 return {sgpr1, sgpr2};
2492 }
2493
2494 std::pair<Value, Value> setTensorDim1(MakeDmaDescriptorOp op,
2495 OpAdaptor adaptor,
2496 ConversionPatternRewriter &rewriter,
2497 Location loc, Value sgpr2, Value sgpr3,
2498 ArrayRef<Value> consts) const {
2499 // TODO: Generalize to setTensorDimX.
2500 SmallVector<OpFoldResult> mixedGlobalSizes = op.getMixedGlobalSizes();
2501 OpFoldResult tensorDim1OpFoldResult = *(mixedGlobalSizes.rbegin() + 1);
2502 Value tensorDim1;
2503 if (auto attr = dyn_cast<Attribute>(tensorDim1OpFoldResult))
2504 tensorDim1 =
2505 createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
2506 else
2507 tensorDim1 = cast<Value>(tensorDim1OpFoldResult);
2508
2509 Value c16 = createI32Constant(rewriter, loc, 16);
2510 Value tensorDim1High = LLVM::LShrOp::create(rewriter, loc, tensorDim1, c16);
2511 sgpr2 = setValueAtOffset(rewriter, loc, sgpr2, tensorDim1, 80);
2512 sgpr3 = setValueAtOffset(rewriter, loc, sgpr3, tensorDim1High, 80 + 16);
2513 return {sgpr2, sgpr3};
2514 }
2515
2516 Value setTileDimX(MakeDmaDescriptorOp op, OpAdaptor adaptor,
2517 ConversionPatternRewriter &rewriter, Location loc,
2518 Value sgpr, ArrayRef<Value> consts, size_t dimX,
2519 int64_t offset) const {
2520 SmallVector<OpFoldResult> mixedSharedSizes = op.getMixedSharedSizes();
2521
2522 if (mixedSharedSizes.size() <= dimX)
2523 return sgpr;
2524
2525 OpFoldResult tileDimXOpFoldResult = *(mixedSharedSizes.rbegin() + dimX);
2526 Value tileDimX;
2527 if (auto attr = dyn_cast<Attribute>(tileDimXOpFoldResult))
2528 tileDimX =
2529 createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
2530 else
2531 tileDimX = cast<Value>(tileDimXOpFoldResult);
2532
2533 return setValueAtOffset(rewriter, loc, sgpr, tileDimX, offset);
2534 }
2535
2536 Value setTileDim0(MakeDmaDescriptorOp op, OpAdaptor adaptor,
2537 ConversionPatternRewriter &rewriter, Location loc,
2538 Value sgpr3, ArrayRef<Value> consts) const {
2539 return setTileDimX(op, adaptor, rewriter, loc, sgpr3, consts, 0, 112);
2540 }
2541
2542 Value setTileDim1(MakeDmaDescriptorOp op, OpAdaptor adaptor,
2543 ConversionPatternRewriter &rewriter, Location loc,
2544 Value sgpr4, ArrayRef<Value> consts) const {
2545 return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 1, 128);
2546 }
2547
2548 Value setTileDim2(MakeDmaDescriptorOp op, OpAdaptor adaptor,
2549 ConversionPatternRewriter &rewriter, Location loc,
2550 Value sgpr4, ArrayRef<Value> consts) const {
2551 return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 2, 144);
2552 }
2553
2554 std::pair<Value, Value>
2555 setTensorDimXStride(MakeDmaDescriptorOp op, OpAdaptor adaptor,
2556 ConversionPatternRewriter &rewriter, Location loc,
2557 Value sgprY, Value sgprZ, ArrayRef<Value> consts,
2558 size_t dimX, int64_t offset) const {
2559 SmallVector<OpFoldResult> mixedGlobalStrides = op.getMixedGlobalStrides();
2560
2561 if (mixedGlobalStrides.size() <= dimX)
2562 return {sgprY, sgprZ};
2563
2564 OpFoldResult tensorDimXStrideOpFoldResult =
2565 *(mixedGlobalStrides.rbegin() + dimX);
2566 Value tensorDimXStride;
2567 if (auto attr = dyn_cast<Attribute>(tensorDimXStrideOpFoldResult))
2568 tensorDimXStride =
2569 createI64Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
2570 else
2571 tensorDimXStride = cast<Value>(tensorDimXStrideOpFoldResult);
2572
2573 constexpr int64_t first48bits = (1ll << 48) - 1;
2574 Value mask = createI64Constant(rewriter, loc, first48bits);
2575 tensorDimXStride =
2576 LLVM::AndOp::create(rewriter, loc, mask, tensorDimXStride);
2577 IntegerType i32 = rewriter.getI32Type();
2578 Value tensorDimXStrideLow =
2579 LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStride);
2580
2581 int64_t shift = (offset % 32) == 0 ? 32 : offset % 32;
2582 Value shiftVal = createI64Constant(rewriter, loc, shift);
2583 Value tensorDimXStrideHigh =
2584 LLVM::LShrOp::create(rewriter, loc, tensorDimXStride, shiftVal);
2585 tensorDimXStrideHigh =
2586 LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStrideHigh);
2587
2588 sgprY = setValueAtOffset(rewriter, loc, sgprY, tensorDimXStrideLow, offset);
2589 sgprZ = setValueAtOffset(rewriter, loc, sgprZ, tensorDimXStrideHigh,
2590 offset + shift);
2591 return {sgprY, sgprZ};
2592 }
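  // Editor's note: a hypothetical illustration (not used by the lowering) of
  // the dim-0 stride split above: the stride is masked to 48 bits, its low
  // 32 bits land at descriptor offset 160 (word 5), and the remaining 16
  // bits are shifted down by 32 and land at offset 192 (word 6).
  static constexpr uint64_t kExampleStrideMask = (1ull << 48) - 1;
  static constexpr uint32_t exampleStrideLow(uint64_t stride) {
    return static_cast<uint32_t>(stride & kExampleStrideMask);
  }
  static constexpr uint32_t exampleStrideHigh(uint64_t stride) {
    return static_cast<uint32_t>((stride & kExampleStrideMask) >> 32);
  }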
2593
2594 std::pair<Value, Value>
2595 setTensorDim0Stride(MakeDmaDescriptorOp op, OpAdaptor adaptor,
2596 ConversionPatternRewriter &rewriter, Location loc,
2597 Value sgpr5, Value sgpr6, ArrayRef<Value> consts) const {
2598 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts,
2599 0, 160);
2600 }
2601
2602 std::pair<Value, Value>
2603 setTensorDim1Stride(MakeDmaDescriptorOp op, OpAdaptor adaptor,
2604 ConversionPatternRewriter &rewriter, Location loc,
2605 Value sgpr5, Value sgpr6, ArrayRef<Value> consts) const {
2606 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts,
2607 1, 208);
2608 }
2609
2610 Value getDGroup1(MakeDmaDescriptorOp op, OpAdaptor adaptor,
2611 ConversionPatternRewriter &rewriter, Location loc,
2612 ArrayRef<Value> consts) const {
2613 Value sgprs[8];
2614 for (int64_t i = 0; i < 8; i++) {
2615 sgprs[i] = consts[0];
2616 }
2617
2618 sgprs[0] = setDataSize(op, adaptor, rewriter, loc, sgprs[0], consts);
2619 sgprs[0] = setAtomicBarrier(op, adaptor, rewriter, loc, sgprs[0], consts);
2620 sgprs[0] = setIterateEnable(op, adaptor, rewriter, loc, sgprs[0], consts);
2621 sgprs[0] = setPadEnable(op, adaptor, rewriter, loc, sgprs[0], consts);
2622 sgprs[0] = setPadInterval(op, adaptor, rewriter, loc, sgprs[0], consts);
2623 sgprs[0] = setPadAmount(op, adaptor, rewriter, loc, sgprs[0], consts);
2624
2625 sgprs[1] =
2626 setAtomicBarrierAddress(op, adaptor, rewriter, loc, sgprs[1], consts);
2627 std::tie(sgprs[1], sgprs[2]) =
2628 setTensorDim0(op, adaptor, rewriter, loc, sgprs[1], sgprs[2], consts);
2629 std::tie(sgprs[2], sgprs[3]) =
2630 setTensorDim1(op, adaptor, rewriter, loc, sgprs[2], sgprs[3], consts);
2631
2632 sgprs[3] = setTileDim0(op, adaptor, rewriter, loc, sgprs[3], consts);
2633 sgprs[4] = setTileDim1(op, adaptor, rewriter, loc, sgprs[4], consts);
2634 sgprs[4] = setTileDim2(op, adaptor, rewriter, loc, sgprs[4], consts);
2635 std::tie(sgprs[5], sgprs[6]) = setTensorDim0Stride(
2636 op, adaptor, rewriter, loc, sgprs[5], sgprs[6], consts);
2637 std::tie(sgprs[6], sgprs[7]) = setTensorDim1Stride(
2638 op, adaptor, rewriter, loc, sgprs[6], sgprs[7], consts);
2639
2640 IntegerType i32 = rewriter.getI32Type();
2641 Type v8i32 = this->typeConverter->convertType(VectorType::get(8, i32));
2642 assert(v8i32 && "expected type conversion to succeed");
2643 Value dgroup1 = LLVM::PoisonOp::create(rewriter, loc, v8i32);
2644
2645 for (auto [sgpr, constant] : llvm::zip_equal(sgprs, consts)) {
2646 dgroup1 =
2647 LLVM::InsertElementOp::create(rewriter, loc, dgroup1, sgpr, constant);
2648 }
2649
2650 return dgroup1;
2651 }
2652
2653 LogicalResult
2654 matchAndRewrite(MakeDmaDescriptorOp op, OpAdaptor adaptor,
2655 ConversionPatternRewriter &rewriter) const override {
2656 if (chipset < kGfx1250)
2657 return op->emitOpError(
2658 "make_dma_descriptor is only supported on gfx1250");
2659
2660 if (op.getRank() > 2)
2661 return op->emitOpError("unimplemented");
2662
2663 Location loc = op.getLoc();
2664
2665 IntegerType i32 = rewriter.getI32Type();
2666 [[maybe_unused]] Type v4i32 =
2667 this->typeConverter->convertType(VectorType::get(4, i32));
2668 assert(v4i32 && "expected type conversion to succeed");
2669
2670 SmallVector<Value> consts;
2671 for (int64_t i = 0; i < 8; i++)
2672 consts.push_back(createI32Constant(rewriter, loc, i));
2673
2674 Value dgroup0 = this->getDGroup0(adaptor);
2675 Value dgroup1 = this->getDGroup1(op, adaptor, rewriter, loc, consts);
2676
2677 SmallVector<Value> results = {dgroup0, dgroup1};
2678 rewriter.replaceOpWithMultiple(op, {results});
2679 return success();
2680 }
2681};
2682
2683struct ConvertAMDGPUToROCDLPass
2684 : public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
2685 using Base::Base;
2686
2687 void runOnOperation() override {
2688 MLIRContext *ctx = &getContext();
2689 FailureOr<Chipset> maybeChipset = Chipset::parse(chipset);
2690 if (failed(maybeChipset)) {
2691 emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
2692 return signalPassFailure();
2693 }
2694
2695 RewritePatternSet patterns(ctx);
2696 LLVMTypeConverter converter(ctx);
2697 converter.addConversion([&](TDMBaseType type) -> Type {
2698 Type i32 = IntegerType::get(type.getContext(), 32);
2699 return converter.convertType(VectorType::get(4, i32));
2700 });
2701
2702 populateAMDGPUToROCDLConversionPatterns(converter, patterns, *maybeChipset);
2703 LLVMConversionTarget target(getContext());
2704 target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
2705 target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
2706 target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
2707 if (failed(applyPartialConversion(getOperation(), target,
2708 std::move(patterns))))
2709 signalPassFailure();
2710 }
2711};
2712} // namespace
2713
2714void mlir::populateAMDGPUMemorySpaceAttributeConversions(
2715 TypeConverter &typeConverter) {
2716 typeConverter.addTypeAttributeConversion(
2717 [](BaseMemRefType type, amdgpu::AddressSpaceAttr as)
2718 -> TypeConverter::AttributeConversionResult {
2719 MLIRContext *ctx = as.getContext();
2720 Type i64 = IntegerType::get(ctx, 64);
2721 switch (as.getValue()) {
2722 case amdgpu::AddressSpace::FatRawBuffer:
2723 return IntegerAttr::get(i64, 7);
2724 case amdgpu::AddressSpace::BufferRsrc:
2725 return IntegerAttr::get(i64, 8);
2726 case amdgpu::AddressSpace::FatStructuredBuffer:
2727 return IntegerAttr::get(i64, 9);
2728 }
2729 return TypeConverter::AttributeConversionResult::abort();
2730 });
2731}
2732
2733void mlir::populateAMDGPUToROCDLConversionPatterns(
2734 const LLVMTypeConverter &converter, RewritePatternSet &patterns,
2735 Chipset chipset) {
2737 patterns.add<
2738 FatRawBufferCastLowering,
2739 RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
2740 RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
2741 RawBufferOpLowering<RawBufferAtomicFaddOp,
2742 ROCDL::RawPtrBufferAtomicFaddOp>,
2743 RawBufferOpLowering<RawBufferAtomicFmaxOp,
2744 ROCDL::RawPtrBufferAtomicFmaxOp>,
2745 RawBufferOpLowering<RawBufferAtomicSmaxOp,
2746 ROCDL::RawPtrBufferAtomicSmaxOp>,
2747 RawBufferOpLowering<RawBufferAtomicUminOp,
2748 ROCDL::RawPtrBufferAtomicUminOp>,
2749 RawBufferOpLowering<RawBufferAtomicCmpswapOp,
2750 ROCDL::RawPtrBufferAtomicCmpSwap>,
2751 AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
2752 SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
2753 WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedMatrixOpLowering,
2754 ScaledExtPackedOpLowering, PackedScaledTruncOpLowering,
2755 PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
2756 GatherToLDSOpLowering, TransposeLoadOpLowering, AMDGPUPermlaneLowering,
2757 AMDGPUMakeDmaBaseLowering, AMDGPUMakeDmaDescriptorLowering>(converter,
2758 chipset);
2759 patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
2760}
static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type)
Return true if type is the E4M3FN variant of an 8-bit float that is supported by the _fp8 instruction...
constexpr Chipset kGfx942
static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter, Location loc, Value input, bool allowBf16=true)
Converts a MFMA vector operand from MLIR AMDGPU dialect convention to ROCDL and LLVM AMDGPU intrinsic...
static std::optional< StringRef > wmmaOpToIntrinsicRDNA(Type elemSourceType, Type elemBSourceType, Type elemDestType, uint32_t k, bool isRDNA3)
Returns the rocdl intrinsic corresponding to a WMMA operation wmma for RDNA3/4 architectures.
static std::optional< std::tuple< StringRef, uint32_t, uint32_t > > mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m, uint32_t n, uint32_t k, uint32_t b, Chipset chipset)
If there is a scaled MFMA instruction for the input element types aType and bType,...
static std::optional< StringRef > mfmaOpToIntrinsic(MFMAOp mfma, Chipset chipset)
Return the rocdl intrinsic corresponding to a MFMA operation mfma if one exists.
constexpr Chipset kGfx908
static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput, Value mlirInput, SmallVectorImpl< Value > &operands, SmallVectorImpl< NamedAttribute > &attrs, StringRef attrName)
Push an input operand.
constexpr Chipset kGfx1250
constexpr Chipset kGfx90a
static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, Value output, int32_t subwordOffset, bool clamp, SmallVectorImpl< Value > &operands, SmallVectorImpl< NamedAttribute > &attrs)
Push the output operand.
static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type)
Return true if type is the E5M2 variant of an 8-bit float that is supported by the _bf8 instructions ...
static std::optional< StringRef > wmmaOpToIntrinsic(WMMAOp wmma, Chipset chipset)
Returns the rocdl intrinsic corresponding to a WMMA operation wmma if one exists.
static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc, Value basePointer, Value numRecords, bool boundsCheck, amdgpu::Chipset chipset, Value cacheSwizzleStride=nullptr, unsigned addressSpace=8)
static Value createI64Constant(ConversionPatternRewriter &rewriter, Location loc, int64_t value)
static std::optional< StringRef > wmmaOpToIntrinsicGfx1250(Type elemSourceType, Type elemBSourceType, Type elemDestType, uint32_t k)
Return the rocdl intrinsic corresponding to a WMMA operation wmma for the gfx1250 architecture.
static Value getLinearIndexI32(ConversionPatternRewriter &rewriter, Location loc, MemRefDescriptor &memRefDescriptor, ValueRange indices, ArrayRef< int64_t > strides)
Returns the linear index used to access an element in the memref.
static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter, Location loc, Value val)
Convert an unsigned number val to i32.
static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter, Location loc, Value input)
Converts the scaled MFMA operands, scalesA and scalesB, from MLIR AMDGPU dialect convention to ROCDL ...
static Value createI32Constant(ConversionPatternRewriter &rewriter, Location loc, int32_t value)
static Value convertUnsignedToI64(ConversionPatternRewriter &rewriter, Location loc, Value val)
Convert an unsigned number val to i64.
constexpr Chipset kGfx950
static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc, MemRefType memrefType, MemRefDescriptor &memrefDescriptor, ArrayRef< int64_t > strides, int64_t elementByteWidth)
Compute the contents of the num_records field for a given memref descriptor - that is,...
static std::optional< uint32_t > mfmaTypeSelectCode(Type mlirElemType)
static constexpr unsigned kSizePosInMemRefDescriptor
static constexpr unsigned kStridePosInMemRefDescriptor
static constexpr unsigned kOffsetPosInMemRefDescriptor
static constexpr unsigned kAllocatedPtrPosInMemRefDescriptor
static constexpr unsigned kAlignedPtrPosInMemRefDescriptor
static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound)
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class provides a shared interface for ranked and unranked memref types.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition Pattern.h:207
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition Pattern.h:213
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Value stride(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th stride from the descriptor.
Value size(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
result_range getResults()
Definition Operation.h:415
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isF64() const
Definition Types.cpp:41
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition Types.cpp:76
bool isF32() const
Definition Types.cpp:40
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition Types.cpp:88
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:56
bool isF16() const
Definition Types.cpp:38
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:122
bool isBF16() const
Definition Types.cpp:37
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
::mlir::Pass::Option< std::string > chipset
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDes...
Definition Pattern.cpp:471
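
As a usage sketch for getStridedElementPtr, assuming a surrounding ConvertOpToLLVMPattern whose matchAndRewrite provides op, adaptor, and rewriter; the getMemref/getIndices accessor names are hypothetical:

// Compute the address of the element addressed by the op's indices; the
// no-wrap flags default to none as in the declaration above.
auto memrefType = cast<MemRefType>(op.getMemref().getType());
Value elemPtr =
    getStridedElementPtr(rewriter, op.getLoc(), *this->getTypeConverter(),
                         memrefType, adaptor.getMemref(),
                         adaptor.getIndices());
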
Value composeValue(OpBuilder &builder, Location loc, ValueRange src, Type dstType)
Composes a set of src values into a single value of type dstType through series of bitcasts and vecto...
Definition Pattern.cpp:432
SmallVector< Value > decomposeValue(OpBuilder &builder, Location loc, Value src, Type dstType)
Decomposes a src value into a set of values of type dstType through series of bitcasts and vector ops...
Definition Pattern.cpp:393
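
A round-trip sketch for the two helpers above (namespace qualification omitted; rewriter, loc, and the vector value someVector are assumed to come from a surrounding conversion pattern, and the vector shape is illustrative):

// Split e.g. a vector<2xi32> value into scalar i32 pieces, then rebuild it.
Type i32 = rewriter.getI32Type();
SmallVector<Value> pieces = decomposeValue(rewriter, loc, someVector, i32);
Value rebuilt =
    composeValue(rewriter, loc, pieces, VectorType::get({2}, i32));
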
bool hasOcpFp8(const Chipset &chipset)
Definition Chipset.h:52
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
void populateAMDGPUMemorySpaceAttributeConversions(TypeConverter &typeConverter)
Remap AMDGPU memory spaces to LLVM address spaces by mapping amdgpu::AddressSpace::fat_raw_buffer to ...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition Matchers.h:442
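
Tying the matchPattern and m_Zero entries together, a small hedged sketch (the value name indexOperand is hypothetical):

// Skip an add when an index operand is a known constant zero.
if (matchPattern(indexOperand, m_Zero())) {
  // indexOperand is a constant 0 (scalar or splat); the addition can be
  // folded away by the lowering.
}
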
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, amdgpu::Chipset chipset)
Note: This function will also add conversions for the AMDGPU-specific address spaces,...
Represents the amdgpu gfx chipset version, e.g., gfx90a, gfx942, gfx1103.
Definition Chipset.h:22
unsigned majorVersion
Definition Chipset.h:23
unsigned minorVersion
Definition Chipset.h:24
static FailureOr< Chipset > parse(StringRef name)
Parses the chipset version string and returns the chipset on success, and failure otherwise.
Definition Chipset.cpp:14
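
Finally, a short sketch of Chipset::parse together with the comparison-based feature gating used throughout this file ("gfx942" is an illustrative chipset name; kGfx908 is the constant defined near the top of the file):

FailureOr<Chipset> maybeChipset = Chipset::parse("gfx942");
if (succeeded(maybeChipset)) {
  // Chipset defines the usual comparison operators, so feature checks are
  // simple version comparisons.
  bool hasMfma = *maybeChipset >= kGfx908;
  (void)hasMfma;
}
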