MLIR 23.0.0git
AMDGPUToROCDL.cpp
Go to the documentation of this file.
1//===- AMDGPUToROCDL.cpp - AMDGPU to ROCDL dialect conversion -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"

#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/AMDGPUAddrSpace.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"
#include <cstdint>
#include <optional>
37
38namespace mlir {
39#define GEN_PASS_DEF_CONVERTAMDGPUTOROCDLPASS
40#include "mlir/Conversion/Passes.h.inc"
41} // namespace mlir
42
43using namespace mlir;
44using namespace mlir::amdgpu;
45
46// Define commonly used chipsets versions for convenience.
47constexpr Chipset kGfx908 = Chipset(9, 0, 8);
48constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
49constexpr Chipset kGfx942 = Chipset(9, 4, 2);
50constexpr Chipset kGfx950 = Chipset(9, 5, 0);
51constexpr Chipset kGfx1250 = Chipset(12, 5, 0);
52
53/// Convert an unsigned number `val` to i32.
54static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
55 Location loc, Value val) {
56 IntegerType i32 = rewriter.getI32Type();
57 // Force check that `val` is of int type.
58 auto valTy = cast<IntegerType>(val.getType());
59 if (i32 == valTy)
60 return val;
61 return valTy.getWidth() > 32
62 ? Value(LLVM::TruncOp::create(rewriter, loc, i32, val))
63 : Value(LLVM::ZExtOp::create(rewriter, loc, i32, val));
64}
65
66static Value createI32Constant(ConversionPatternRewriter &rewriter,
67 Location loc, int32_t value) {
68 return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), value);
69}
70
71/// Convert an unsigned number `val` to i64.
72static Value convertUnsignedToI64(ConversionPatternRewriter &rewriter,
73 Location loc, Value val) {
74 IntegerType i64 = rewriter.getI64Type();
75 // Force check that `val` is of int type.
76 auto valTy = cast<IntegerType>(val.getType());
77 if (i64 == valTy)
78 return val;
79 return valTy.getWidth() > 64
80 ? Value(LLVM::TruncOp::create(rewriter, loc, i64, val))
81 : Value(LLVM::ZExtOp::create(rewriter, loc, i64, val));
82}
83
84static Value createI64Constant(ConversionPatternRewriter &rewriter,
85 Location loc, int64_t value) {
86 return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), value);
87}
88
89/// Returns the linear index used to access an element in the memref.
90static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
91 Location loc, MemRefDescriptor &memRefDescriptor,
93 IntegerType i32 = rewriter.getI32Type();
95 for (auto [i, increment, stride] : llvm::enumerate(indices, strides)) {
96 if (stride != 1) { // Skip if stride is 1.
97 Value strideValue =
98 ShapedType::isDynamic(stride)
99 ? convertUnsignedToI32(rewriter, loc,
100 memRefDescriptor.stride(rewriter, loc, i))
101 : LLVM::ConstantOp::create(rewriter, loc, i32, stride);
102 increment = LLVM::MulOp::create(rewriter, loc, increment, strideValue);
103 }
104 index = index ? LLVM::AddOp::create(rewriter, loc, index, increment)
105 : increment;
106 }
107 return index ? index : createI32Constant(rewriter, loc, 0);
108}
109
110/// Compute the contents of the `num_records` field for a given memref
111/// descriptor - that is, the number of bytes that's one element past the
112/// greatest possible valid index into the memref.
113static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc,
114 MemRefType memrefType,
115 MemRefDescriptor &memrefDescriptor,
116 ArrayRef<int64_t> strides, int64_t elementByteWidth,
117 amdgpu::Chipset chipset, bool boundsCheck) {
118 if (chipset >= kGfx1250 && !boundsCheck) {
119 constexpr int64_t first45bits = (1ll << 45) - 1;
120 return createI64Constant(rewriter, loc, first45bits);
121 }
122 if (memrefType.hasStaticShape() &&
123 !llvm::any_of(strides, ShapedType::isDynamic)) {
124 int64_t size = memrefType.getRank() == 0 ? 1 : 0;
125 ArrayRef<int64_t> shape = memrefType.getShape();
126 for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
127 size = std::max(shape[i] * strides[i], size);
128 size = size * elementByteWidth;
129 return createI64Constant(rewriter, loc, size);
130 }
131 Value maxIndex;
132 for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
133 Value size = memrefDescriptor.size(rewriter, loc, i);
134 Value stride = memrefDescriptor.stride(rewriter, loc, i);
135 Value maxThisDim = LLVM::MulOp::create(rewriter, loc, size, stride);
136 maxIndex = maxIndex
137 ? LLVM::UMaxOp::create(rewriter, loc, maxIndex, maxThisDim)
138 : maxThisDim;
139 }
140 Value maxIndexI64 = convertUnsignedToI64(rewriter, loc, maxIndex);
141 Value byteWidthConst = createI64Constant(rewriter, loc, elementByteWidth);
142 return LLVM::MulOp::create(rewriter, loc, maxIndexI64, byteWidthConst);
143}
144
145static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc,
146 Value basePointer, Value numRecords,
147 bool boundsCheck, amdgpu::Chipset chipset,
148 Value cacheSwizzleStride = nullptr,
149 unsigned addressSpace = 8) {
150 // The stride value is generally 0. However, on MI-300 and onward, you can
151 // enable a cache swizzling mode by setting bit 14 of the stride field
152 // and setting that stride to a cache stride.
153 Type i16 = rewriter.getI16Type();
154 Value stride;
155 if (chipset.majorVersion == 9 && chipset >= kGfx942 && cacheSwizzleStride) {
156 Value cacheStrideZext =
157 LLVM::ZExtOp::create(rewriter, loc, i16, cacheSwizzleStride);
158 Value swizzleBit = LLVM::ConstantOp::create(
159 rewriter, loc, i16, rewriter.getI16IntegerAttr(1 << 14));
160 stride = LLVM::OrOp::create(rewriter, loc, cacheStrideZext, swizzleBit,
161 /*isDisjoint=*/true);
162 } else {
163 stride = LLVM::ConstantOp::create(rewriter, loc, i16,
164 rewriter.getI16IntegerAttr(0));
165 }
166
167 uint32_t flags = 0;
168 if (chipset >= kGfx1250) {
169 // Flag word:
170 // bit 0: swizzle
171 // bit 1: 0 means (total_offset + payload > numRecords)
172 // 1 means ((total_offset + payload >) numRecords) || ((offset +
173 // payload) > stride) only applied when swizzle_enable = 0. keep at
174 // zero.
175 // whether oob is done depends on numRecords.
176 // bits 2-3: Type (must be 0)
177 } else {
178 // Get the number of elements.
179 // Flag word:
180 // bits 0-11: dst sel, ignored by these intrinsics
181 // bits 12-14: data format (ignored, must be nonzero, 7=float)
182 // bits 15-18: data format (ignored, must be nonzero, 4=32bit)
183 // bit 19: In nested heap (0 here)
184 // bit 20: Behavior on unmap (0 means "return 0 / ignore")
185 // bits 21-22: Index stride for swizzles (N/A)
186 // bit 23: Add thread ID (0)
187 // bit 24: Reserved to 1 (RDNA) or 0 (CDNA)
188 // bits 25-26: Reserved (0)
189 // bit 27: Buffer is non-volatile (CDNA only)
190 // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 =
191 // none, 3 = either swizzles or testing against offset field) RDNA only
192 // bits 30-31: Type (must be 0)
193 flags |= (7 << 12) | (4 << 15);
194 if (chipset.majorVersion >= 10) {
195 flags |= (1 << 24);
196 uint32_t oob = boundsCheck ? 3 : 2;
197 flags |= (oob << 28);
198 }
199 }
200 Value flagsConst = createI32Constant(rewriter, loc, flags);
201 Type rsrcType =
202 LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
203 Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
204 loc, rsrcType, basePointer, stride, numRecords, flagsConst);
205 return resource;
206}
207
208namespace {
209struct FatRawBufferCastLowering
210 : public ConvertOpToLLVMPattern<FatRawBufferCastOp> {
211 FatRawBufferCastLowering(const LLVMTypeConverter &converter, Chipset chipset)
212 : ConvertOpToLLVMPattern<FatRawBufferCastOp>(converter),
213 chipset(chipset) {}
214
215 Chipset chipset;
216
217 LogicalResult
218 matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor,
219 ConversionPatternRewriter &rewriter) const override {
220 Location loc = op.getLoc();
221 Value memRef = adaptor.getSource();
222 Value unconvertedMemref = op.getSource();
223 MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());
224 MemRefDescriptor descriptor(memRef);
225
226 DataLayout dataLayout = DataLayout::closest(op);
227 int64_t elementByteWidth =
228 dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;
229
230 int64_t unusedOffset = 0;
231 SmallVector<int64_t, 5> strideVals;
232 if (failed(memrefType.getStridesAndOffset(strideVals, unusedOffset)))
233 return op.emitOpError("Can't lower non-stride-offset memrefs");
234
235 Value numRecords = adaptor.getValidBytes();
236 if (!numRecords)
237 numRecords =
238 getNumRecords(rewriter, loc, memrefType, descriptor, strideVals,
239 elementByteWidth, chipset, adaptor.getBoundsCheck());
240
241 Value basePointer =
242 adaptor.getResetOffset()
243 ? descriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
244 memrefType)
245 : descriptor.alignedPtr(rewriter, loc);
246
247 Value offset = adaptor.getResetOffset()
248 ? LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
249 rewriter.getIndexAttr(0))
250 : descriptor.offset(rewriter, loc);
251
252 bool hasSizes = memrefType.getRank() > 0;
253 // No need to unpack() and pack() all the individual sizes and strides,
254 // so we'll just extract the arrays.
255 Value sizes = hasSizes
256 ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
258 : Value{};
259 Value strides =
260 hasSizes ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
262 : Value{};
263
264 Value fatPtr = makeBufferRsrc(
265 rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(),
266 chipset, adaptor.getCacheSwizzleStride(), /*addressSpace=*/7);
267
268 Value result = MemRefDescriptor::poison(
269 rewriter, loc,
270 getTypeConverter()->convertType(op.getResult().getType()));
271 SmallVector<int64_t> pos{kAllocatedPtrPosInMemRefDescriptor};
272 result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr, pos);
273 result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr,
275 result = LLVM::InsertValueOp::create(rewriter, loc, result, offset,
277 if (hasSizes) {
278 result = LLVM::InsertValueOp::create(rewriter, loc, result, sizes,
280 result = LLVM::InsertValueOp::create(rewriter, loc, result, strides,
282 }
283 rewriter.replaceOp(op, result);
284 return success();
285 }
286};
287
288/// Define lowering patterns for raw buffer ops
289template <typename GpuOp, typename Intrinsic>
290struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
291 RawBufferOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
292 : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}
293
294 Chipset chipset;
295 static constexpr uint32_t maxVectorOpWidth = 128;
296
297 LogicalResult
298 matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor,
299 ConversionPatternRewriter &rewriter) const override {
300 Location loc = gpuOp.getLoc();
301 Value memref = adaptor.getMemref();
302 Value unconvertedMemref = gpuOp.getMemref();
303 MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());
304
305 if (chipset.majorVersion < 9)
306 return gpuOp.emitOpError("raw buffer ops require GCN or higher");
307
308 Value storeData = adaptor.getODSOperands(0)[0];
309 if (storeData == memref) // no write component to this op
310 storeData = Value();
311 Type wantedDataType;
312 if (storeData)
313 wantedDataType = storeData.getType();
314 else
315 wantedDataType = gpuOp.getODSResults(0)[0].getType();
316
317 Value atomicCmpData = Value();
318 // Operand index 1 of a load is the indices, trying to read them can crash.
319 if (storeData) {
320 Value maybeCmpData = adaptor.getODSOperands(1)[0];
321 if (maybeCmpData != memref)
322 atomicCmpData = maybeCmpData;
323 }
324
325 Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
326
327 Type i32 = rewriter.getI32Type();
328
329 // Get the type size in bytes.
330 DataLayout dataLayout = DataLayout::closest(gpuOp);
331 int64_t elementByteWidth =
332 dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;
333 Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);
334
335 // If we want to load a vector<NxT> with total size <= 32
336 // bits, use a scalar load and bitcast it. Similarly, if bitsize(T) < 32
337 // and the total load size is >= 32, use a vector load of N / (bitsize(T) /
338 // 32) x i32 and bitcast. Also, the CAS intrinsic requires integer operands,
339 // so bitcast any floats to integers.
340 Type llvmBufferValType = llvmWantedDataType;
341 if (atomicCmpData) {
342 if (auto floatType = dyn_cast<FloatType>(wantedDataType))
343 llvmBufferValType = this->getTypeConverter()->convertType(
344 rewriter.getIntegerType(floatType.getWidth()));
345 }
346 if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
347 uint32_t vecLen = dataVector.getNumElements();
348 uint32_t elemBits =
349 dataLayout.getTypeSizeInBits(dataVector.getElementType());
350 uint32_t totalBits = elemBits * vecLen;
351 bool usePackedFp16 =
352 isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
353 if (totalBits > maxVectorOpWidth)
354 return gpuOp.emitOpError(
355 "Total width of loads or stores must be no more than " +
356 Twine(maxVectorOpWidth) + " bits, but we call for " +
357 Twine(totalBits) +
358 " bits. This should've been caught in validation");
359 if (!usePackedFp16 && elemBits < 32) {
360 if (totalBits > 32) {
361 if (totalBits % 32 != 0)
362 return gpuOp.emitOpError("Load or store of more than 32-bits that "
363 "doesn't fit into words. Can't happen\n");
364 llvmBufferValType = this->typeConverter->convertType(
365 VectorType::get(totalBits / 32, i32));
366 } else {
367 llvmBufferValType = this->typeConverter->convertType(
368 rewriter.getIntegerType(totalBits));
369 }
370 }
371 }
372 if (auto vecType = dyn_cast<VectorType>(llvmBufferValType)) {
373 // Buffer intrinsics doesn't support 1-element vectors, cast them to
374 // scalars.
375 if (vecType.getNumElements() == 1)
376 llvmBufferValType = vecType.getElementType();
377 }
378
379 SmallVector<Value, 6> args;
380 if (storeData) {
381 if (llvmBufferValType != llvmWantedDataType) {
382 Value castForStore = LLVM::BitcastOp::create(
383 rewriter, loc, llvmBufferValType, storeData);
384 args.push_back(castForStore);
385 } else {
386 args.push_back(storeData);
387 }
388 }
389
390 if (atomicCmpData) {
391 if (llvmBufferValType != llvmWantedDataType) {
392 Value castForCmp = LLVM::BitcastOp::create(
393 rewriter, loc, llvmBufferValType, atomicCmpData);
394 args.push_back(castForCmp);
395 } else {
396 args.push_back(atomicCmpData);
397 }
398 }
399
400 // Construct buffer descriptor from memref, attributes
401 int64_t offset = 0;
402 SmallVector<int64_t, 5> strides;
403 if (failed(memrefType.getStridesAndOffset(strides, offset)))
404 return gpuOp.emitOpError("Can't lower non-stride-offset memrefs");
405
406 MemRefDescriptor memrefDescriptor(memref);
407
408 Value ptr = memrefDescriptor.bufferPtr(
409 rewriter, loc, *this->getTypeConverter(), memrefType);
410 Value numRecords =
411 getNumRecords(rewriter, loc, memrefType, memrefDescriptor, strides,
412 elementByteWidth, chipset, adaptor.getBoundsCheck());
413 Value resource = makeBufferRsrc(rewriter, loc, ptr, numRecords,
414 adaptor.getBoundsCheck(), chipset);
415 args.push_back(resource);
416
417 // Indexing (voffset)
418 Value voffset = getLinearIndexI32(rewriter, loc, memrefDescriptor,
419 adaptor.getIndices(), strides);
420 if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
421 indexOffset && *indexOffset > 0) {
422 Value extraOffsetConst = createI32Constant(rewriter, loc, *indexOffset);
423 voffset = voffset ? LLVM::AddOp::create(rewriter, loc, voffset,
424 extraOffsetConst)
425 : extraOffsetConst;
426 }
427 voffset = LLVM::MulOp::create(rewriter, loc, voffset, byteWidthConst);
428 args.push_back(voffset);
429
430 // SGPR offset.
431 Value sgprOffset = adaptor.getSgprOffset();
432 if (!sgprOffset)
433 sgprOffset = createI32Constant(rewriter, loc, 0);
434 sgprOffset = LLVM::MulOp::create(rewriter, loc, sgprOffset, byteWidthConst);
435 args.push_back(sgprOffset);
436
437 // bit 0: GLC = 0 (atomics drop value, less coherency)
438 // bits 1-2: SLC, DLC = 0 (similarly)
439 // bit 3: swizzled (0 for raw)
440 args.push_back(createI32Constant(rewriter, loc, 0));
441
442 llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
443 llvmBufferValType);
444 Operation *lowered = Intrinsic::create(rewriter, loc, resultTypes, args,
445 ArrayRef<NamedAttribute>());
446 if (lowered->getNumResults() == 1) {
447 Value replacement = lowered->getResult(0);
448 if (llvmBufferValType != llvmWantedDataType) {
449 replacement = LLVM::BitcastOp::create(rewriter, loc, llvmWantedDataType,
451 }
452 rewriter.replaceOp(gpuOp, replacement);
453 } else {
454 rewriter.eraseOp(gpuOp);
455 }
456 return success();
457 }
458};
459
460// TODO: AMDGPU backend already have all this bitpacking logic, we should move
461// it to some common place.
462/// Vmcnt, Expcnt and Lgkmcnt are decoded as follows:
463/// Vmcnt = Waitcnt[3:0] (pre-gfx9)
464/// Vmcnt = Waitcnt[15:14,3:0] (gfx9,10)
465/// Vmcnt = Waitcnt[15:10] (gfx11)
466/// Expcnt = Waitcnt[6:4] (pre-gfx11)
467/// Expcnt = Waitcnt[2:0] (gfx11)
468/// Lgkmcnt = Waitcnt[11:8] (pre-gfx10)
469/// Lgkmcnt = Waitcnt[13:8] (gfx10)
470/// Lgkmcnt = Waitcnt[9:4] (gfx11)
471static FailureOr<unsigned> encodeWaitcnt(Chipset chipset, unsigned vmcnt,
472 unsigned expcnt, unsigned lgkmcnt) {
473 if (chipset.majorVersion < 9) {
474 vmcnt = std::min(15u, vmcnt);
475 expcnt = std::min(7u, expcnt);
476 lgkmcnt = std::min(15u, lgkmcnt);
477 return vmcnt | (expcnt << 4) | (lgkmcnt << 8);
478 }
479 if (chipset.majorVersion == 9) {
480 vmcnt = std::min(63u, vmcnt);
481 expcnt = std::min(7u, expcnt);
482 lgkmcnt = std::min(15u, lgkmcnt);
483 unsigned lowBits = vmcnt & 0xF;
484 unsigned highBits = (vmcnt >> 4) << 14;
485 unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
486 return lowBits | highBits | otherCnts;
487 }
488 if (chipset.majorVersion == 10) {
489 vmcnt = std::min(63u, vmcnt);
490 expcnt = std::min(7u, expcnt);
491 lgkmcnt = std::min(63u, lgkmcnt);
492 unsigned lowBits = vmcnt & 0xF;
493 unsigned highBits = (vmcnt >> 4) << 14;
494 unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
495 return lowBits | highBits | otherCnts;
496 }
497 if (chipset.majorVersion == 11) {
498 vmcnt = std::min(63u, vmcnt);
499 expcnt = std::min(7u, expcnt);
500 lgkmcnt = std::min(63u, lgkmcnt);
501 return (vmcnt << 10) | expcnt | (lgkmcnt << 4);
502 }
503 return failure();
504}
505
506struct MemoryCounterWaitOpLowering
507 : public ConvertOpToLLVMPattern<MemoryCounterWaitOp> {
508 MemoryCounterWaitOpLowering(const LLVMTypeConverter &converter,
509 Chipset chipset)
510 : ConvertOpToLLVMPattern<MemoryCounterWaitOp>(converter),
511 chipset(chipset) {}
512
513 Chipset chipset;
514
515 LogicalResult
516 matchAndRewrite(MemoryCounterWaitOp op, OpAdaptor adaptor,
517 ConversionPatternRewriter &rewriter) const override {
518 if (chipset.majorVersion >= 12) {
519 Location loc = op.getLoc();
520 if (std::optional<int> ds = adaptor.getDs())
521 ROCDL::WaitDscntOp::create(rewriter, loc, *ds);
522
523 if (std::optional<int> load = adaptor.getLoad())
524 ROCDL::WaitLoadcntOp::create(rewriter, loc, *load);
525
526 if (std::optional<int> store = adaptor.getStore())
527 ROCDL::WaitStorecntOp::create(rewriter, loc, *store);
528
529 if (std::optional<int> exp = adaptor.getExp())
530 ROCDL::WaitExpcntOp::create(rewriter, loc, *exp);
531
532 if (std::optional<int> tensor = adaptor.getTensor())
533 ROCDL::WaitTensorcntOp::create(rewriter, loc, *tensor);
534
535 rewriter.eraseOp(op);
536 return success();
537 }
538
539 if (adaptor.getTensor())
540 return op.emitOpError("unsupported chipset");
541
542 auto getVal = [](Attribute attr) -> unsigned {
543 if (attr)
544 return cast<IntegerAttr>(attr).getInt();
545
546 // This value will be clamped to the maximum value for the chipset.
547 return 1024;
548 };
549 unsigned ds = getVal(adaptor.getDsAttr());
550 unsigned exp = getVal(adaptor.getExpAttr());
551
552 unsigned vmcnt = 1024;
553 Attribute load = adaptor.getLoadAttr();
554 Attribute store = adaptor.getStoreAttr();
555 if (load && store) {
556 vmcnt = getVal(load) + getVal(store);
557 } else if (load) {
558 vmcnt = getVal(load);
559 } else if (store) {
560 vmcnt = getVal(store);
561 }
562
563 FailureOr<unsigned> waitcnt = encodeWaitcnt(chipset, vmcnt, exp, ds);
564 if (failed(waitcnt))
565 return op.emitOpError("unsupported chipset");
566
567 rewriter.replaceOpWithNewOp<ROCDL::SWaitcntOp>(op, *waitcnt);
568 return success();
569 }
570};
571
572struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
573 LDSBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
574 : ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {}
575
576 Chipset chipset;
577
578 LogicalResult
579 matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
580 ConversionPatternRewriter &rewriter) const override {
581 Location loc = op.getLoc();
582 // This ensures that waits on global memory aren't introduced on
583 // chips that don't have the BackOffBarrier feature enabled in LLVM.
584 bool requiresInlineAsm = chipset < kGfx90a;
585
586 Attribute mmra =
587 rewriter.getAttr<LLVM::MMRATagAttr>("amdgpu-synchronize-as", "local");
588 // Note: while there *is* a workgroup-one-as scope, this, when combined with
589 // the MMRA, will lead to the fence having no effect. This is because the
590 // codepaths for an atomic load or store will observe that a
591 // one-address-space atomic to LDS requires no synchronization because
592 // operations on LDS are totally ordered with respect to each other, and so
593 // will not emit the correct waitcnt operations that these fences are
594 // intended to produce. Therefore, we use a broader type of fence and rely
595 // on the MMRA to relax it to the semantics we want.
596 StringRef scope = "workgroup";
597
598 auto relFence = LLVM::FenceOp::create(rewriter, loc,
599 LLVM::AtomicOrdering::release, scope);
600 relFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra);
601 if (requiresInlineAsm) {
602 auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
603 LLVM::AsmDialect::AD_ATT);
604 const char *asmStr = ";;;WARNING: BREAKS DEBUG WATCHES\ns_barrier";
605 const char *constraints = "";
606 LLVM::InlineAsmOp::create(
607 rewriter, loc,
608 /*resultTypes=*/TypeRange(), /*operands=*/ValueRange(),
609 /*asm_string=*/asmStr, constraints, /*has_side_effects=*/true,
610 /*is_align_stack=*/false, LLVM::TailCallKind::None,
611 /*asm_dialect=*/asmDialectAttr,
612 /*operand_attrs=*/ArrayAttr());
613 } else if (chipset.majorVersion < 12) {
614 ROCDL::SBarrierOp::create(rewriter, loc);
615 } else {
616 ROCDL::BarrierSignalOp::create(rewriter, loc, -1);
617 ROCDL::BarrierWaitOp::create(rewriter, loc, -1);
618 }
619
620 auto acqFence = LLVM::FenceOp::create(rewriter, loc,
621 LLVM::AtomicOrdering::acquire, scope);
622 acqFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra);
623 rewriter.replaceOp(op, acqFence);
624 return success();
625 }
626};
627
628struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
629 SchedBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
630 : ConvertOpToLLVMPattern<SchedBarrierOp>(converter), chipset(chipset) {}
631
632 Chipset chipset;
633
634 LogicalResult
635 matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor,
636 ConversionPatternRewriter &rewriter) const override {
637 rewriter.replaceOpWithNewOp<ROCDL::SchedBarrier>(op,
638 (uint32_t)op.getOpts());
639 return success();
640 }
641};
642
643} // namespace
644
645/// Pack small float vector operands (fp4/fp6/fp8/bf16) into the format
646/// expected by scaled matrix multiply intrinsics (MFMA/WMMA).
647///
648/// Specifically:
649/// 1. If the element type is bfloat16, bitcast it to i16 unless rocdl intrinsic
650/// allows bf16. Newer MFMAs support bf16 types on operand, check
651/// IntrinsicsAMDGPU.td file for reference.
652/// 2. If instead we have a more than 64-bit quantity, use a <N / 4 x i32>
653/// instead, which is what the f8f6f4 intrinsics use.
654/// 3. If `input` is a vector of N <= 8 bytes, bitcast it to a (N * 8)-bit
655/// integer.
656///
657/// Note that the type of `input` has already been LLVM type converted:
658/// therefore 8-bit and smaller floats are represented as their corresponding
659/// `iN` integers.
660static Value packSmallFloatVectorOperand(ConversionPatternRewriter &rewriter,
661 Location loc, Value input,
662 bool allowBf16 = true) {
663 Type inputType = input.getType();
664 if (auto vectorType = dyn_cast<VectorType>(inputType)) {
665 if (vectorType.getElementType().isBF16() && !allowBf16)
666 return LLVM::BitcastOp::create(
667 rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
668 if (vectorType.getElementType().isInteger(8) &&
669 vectorType.getNumElements() <= 8)
670 return LLVM::BitcastOp::create(
671 rewriter, loc,
672 rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
673 if (isa<IntegerType>(vectorType.getElementType()) &&
674 vectorType.getElementTypeBitWidth() <= 8) {
675 int64_t numWords = llvm::divideCeil(
676 vectorType.getNumElements() * vectorType.getElementTypeBitWidth(),
677 32);
678 return LLVM::BitcastOp::create(
679 rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()),
680 input);
681 }
682 }
683 return input;
684}
685
686/// Converts sparse MFMA/WMMA (smfmac/swmmac) operands to the expected ROCDL
687/// types.
688static Value convertSparseVectorOperand(ConversionPatternRewriter &rewriter,
689 Location loc, Value input,
690 bool allowBf16 = true) {
691 Type inputType = input.getType();
692 auto vectorType = cast<VectorType>(inputType);
693 // bf16 -> i16 when not allowed (pre-gfx950).
694 if (vectorType.getElementType().isBF16() && !allowBf16)
695 return LLVM::BitcastOp::create(
696 rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
697 // i8/fp8 vectors -> vector<Nxi32>.
698 if (isa<IntegerType>(vectorType.getElementType()) &&
699 vectorType.getElementTypeBitWidth() <= 8) {
700 int64_t numWords = llvm::divideCeil(
701 vectorType.getNumElements() * vectorType.getElementTypeBitWidth(), 32);
702 Type castType = (numWords > 1)
703 ? Type{VectorType::get(numWords, rewriter.getI32Type())}
704 : rewriter.getI32Type();
705 return LLVM::BitcastOp::create(rewriter, loc, castType, input);
706 }
707 return input;
708}
709
710/// Converts the scaled MFMA/WMMA operands, `scalesA` and `scalesB`, from MLIR
711/// AMDGPU dialect convention to ROCDL and LLVM AMDGPU intrinsics convention.
712///
713/// Specifically:
714/// 1. If `input` is a i8 value, zero extend it to i32
715/// 2. If `input` is a vector of length 4 or 8 and type i8, cast it to i32
716///
717/// Note that the type of `input` has already been LLVM type converted:
718/// therefore 8-bit and smaller floats are represented as their corresponding
719/// `iN` integers.
720static Value castScaleOperand(ConversionPatternRewriter &rewriter, Location loc,
721 Value input) {
722 return TypeSwitch<Type, Value>(input.getType())
723 .Case([&](IntegerType) {
724 // Handle scalar i8: zero extend to i32.
725 return LLVM::ZExtOp::create(rewriter, loc, rewriter.getI32Type(),
726 input);
727 })
728 .Case([&](VectorType vectorType) {
729 // Handle vector<4xi8> -> i32 or vector<8xi8> -> i64.
730 int64_t numElements = vectorType.getNumElements();
731 assert((numElements == 4 || numElements == 8) &&
732 "scale operand must be a vector of length 4 or 8");
733 IntegerType outputType =
734 (numElements == 4) ? rewriter.getI32Type() : rewriter.getI64Type();
735 return LLVM::BitcastOp::create(rewriter, loc, outputType, input);
736 })
737 .DefaultUnreachable("unexpected input type for scale operand");
738}
739
740/// Maps f8 scale element types to WMMA scale format codes.
741static std::optional<uint32_t> getWmmaScaleFormat(Type elemType) {
743 .Case([](Float8E8M0FNUType) { return 0; })
744 .Case([](Float8E4M3FNType) { return 2; })
745 .Default(std::nullopt);
746}
747
748/// Determines the ROCDL intrinsic name for scaled WMMA based on dimensions
749/// and scale block size (16 or 32).
750static std::optional<StringRef>
752 if (m == 16 && n == 16 && k == 128)
753 return isScale16
754 ? ROCDL::wmma_scale16_f32_16x16x128_f8f6f4::getOperationName()
755 : ROCDL::wmma_scale_f32_16x16x128_f8f6f4::getOperationName();
756
757 if (m == 32 && n == 16 && k == 128)
758 return isScale16 ? ROCDL::wmma_scale16_f32_32x16x128_f4::getOperationName()
759 : ROCDL::wmma_scale_f32_32x16x128_f4::getOperationName();
760
761 return std::nullopt;
762}
763
764/// Push an input operand. If it is a float type, nothing to do. If it is
765/// an integer type, then we need to also push its signdness (1 for signed, 0
766/// for unsigned) and we need to pack the input 16xi8 vector into a 4xi32
767/// vector (or the 8xi8 vector into a 2xi32 one for gfx12+).
768/// We also need to convert bfloat inputs to i16 to account for the bfloat
769/// intrinsics having been defined before the AMD backend supported bfloat. We
770/// similarly need to pack 8-bit float types into integers as if they were i8
771/// (which they are for the backend's purposes).
773 ConversionPatternRewriter &rewriter, Location loc,
774 const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput,
775 Value mlirInput, SmallVectorImpl<Value> &operands,
776 SmallVectorImpl<NamedAttribute> &attrs, StringRef attrName) {
777 Type inputType = llvmInput.getType();
778 auto vectorType = dyn_cast<VectorType>(inputType);
779 if (!vectorType) {
780 operands.push_back(llvmInput);
781 return;
782 }
783 Type elemType = vectorType.getElementType();
784 if (elemType.getIntOrFloatBitWidth() > 8) {
785 operands.push_back(llvmInput);
786 return;
787 }
788
789 // We need to check the type of the input before conversion to properly test
790 // for int8. This is because, in LLVM, fp8 type is converted to int8, so the
791 // fp8/int8 information is lost during the conversion process.
792 auto mlirInputType = cast<VectorType>(mlirInput.getType());
793 bool isInputInteger = mlirInputType.getElementType().isInteger();
794 if (isInputInteger) {
795 // if element type is 8-bit signed or unsigned, ignore the isUnsigned flag
796 bool localIsUnsigned = isUnsigned;
797 if (elemType.isUnsignedInteger()) {
798 localIsUnsigned = true;
799 } else if (elemType.isSignedInteger()) {
800 localIsUnsigned = false;
801 }
802 attrs.push_back(
803 NamedAttribute(attrName, rewriter.getBoolAttr(!localIsUnsigned)));
804 }
805
806 int64_t numBits =
807 vectorType.getNumElements() * elemType.getIntOrFloatBitWidth();
808 Type i32 = rewriter.getI32Type();
809 Type intrinsicInType = numBits <= 32
810 ? (Type)rewriter.getIntegerType(numBits)
811 : (Type)VectorType::get(numBits / 32, i32);
812 auto llvmIntrinsicInType = typeConverter->convertType(intrinsicInType);
813 Value castInput = rewriter.createOrFold<LLVM::BitcastOp>(
814 loc, llvmIntrinsicInType, llvmInput);
815 // The wave64-mode 16x16x16 intrinsics that take 4-bit integers only need
816 // (256 / 64) * 4 = 16 bits of input (on gfx12+) but take i32 arguments.
817 // Add in the zeros here.
818 if (numBits < 32)
819 castInput = LLVM::ZExtOp::create(rewriter, loc, i32, castInput);
820 operands.push_back(castInput);
821}
822
823/// Push the output operand. For many cases this is only pushing the output in
824/// the operand list. But when we have f16 -> f16 or bf16 -> bf16 intrinsics,
825/// since the same numbers of VGPRs is used, we need to decide if to store the
826/// result in the upper 16 bits of the VGPRs or in the lower part. To store the
827/// result in the lower 16 bits, set subwordOffset to 1, otherwise result will
828/// be stored it in the upper part. The subwordOffset must not be set for gfx12,
829/// as the instructions have been changed to return fewer registers instead.
830static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
831 Location loc,
832 const TypeConverter *typeConverter,
833 Value output, int32_t subwordOffset,
834 bool clamp, SmallVectorImpl<Value> &operands,
836 Type inputType = output.getType();
837 auto vectorType = dyn_cast<VectorType>(inputType);
838 Type elemType = vectorType.getElementType();
839 operands.push_back(output);
840 if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) {
841 attrs.push_back(
842 NamedAttribute("opsel", rewriter.getBoolAttr(subwordOffset)));
843 } else if (elemType.isInteger(32)) {
844 attrs.push_back(NamedAttribute("clamp", rewriter.getBoolAttr(clamp)));
845 }
846}
847
848/// Return true if `type` is the E5M2 variant of an 8-bit float that is
849/// supported by the `_bf8` instructions on the given `chipset`.
850static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type) {
851 return (chipset == kGfx942 && isa<Float8E5M2FNUZType>(type)) ||
852 (hasOcpFp8(chipset) && isa<Float8E5M2Type>(type));
853}
854
855/// Return true if `type` is the E4M3FN variant of an 8-bit float that is
856/// supported by the `_fp8` instructions on the given `chipset`.
857static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type) {
858 return (chipset == kGfx942 && isa<Float8E4M3FNUZType>(type)) ||
859 (hasOcpFp8(chipset) && isa<Float8E4M3FNType>(type));
860}
861
/// Return the `rocdl` intrinsic corresponding to a MFMA operation `mfma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
                                                  Chipset chipset) {
  uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
           b = mfma.getBlocks();
  Type sourceElem = getElementTypeOrSelf(mfma.getSourceA().getType());
  Type destElem = getElementTypeOrSelf(mfma.getDestC().getType());

  // f32 x f32 -> f32 variants.
  if (sourceElem.isF32() && destElem.isF32()) {
    // The reduced-precision (xf32) forms are only available on gfx942+.
    if (mfma.getReducePrecision() && chipset >= kGfx942) {
      if (m == 32 && n == 32 && k == 4 && b == 1)
        return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
      if (m == 16 && n == 16 && k == 8 && b == 1)
        return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
    }
    if (m == 32 && n == 32 && k == 1 && b == 2)
      return ROCDL::mfma_f32_32x32x1f32::getOperationName();
    if (m == 16 && n == 16 && k == 1 && b == 4)
      return ROCDL::mfma_f32_16x16x1f32::getOperationName();
    if (m == 4 && n == 4 && k == 1 && b == 16)
      return ROCDL::mfma_f32_4x4x1f32::getOperationName();
    if (m == 32 && n == 32 && k == 2 && b == 1)
      return ROCDL::mfma_f32_32x32x2f32::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f32_16x16x4f32::getOperationName();
  }

  // f16 x f16 -> f32 variants.
  if (sourceElem.isF16() && destElem.isF32()) {
    // Larger-K forms introduced on gfx950.
    if (chipset >= kGfx950) {
      if (m == 32 && n == 32 && k == 16 && b == 1)
        return ROCDL::mfma_f32_32x32x16_f16::getOperationName();
      if (m == 16 && n == 16 && k == 32 && b == 1)
        return ROCDL::mfma_f32_16x16x32_f16::getOperationName();
    }
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_f32_32x32x4f16::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_f32_16x16x4f16::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_f32_4x4x4f16::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_f32_32x32x8f16::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_f32_16x16x16f16::getOperationName();
  }

  // bf16 x bf16 -> f32 variants.
  if (sourceElem.isBF16() && destElem.isF32()) {
    // Larger-K forms introduced on gfx950.
    if (chipset >= kGfx950) {
      if (m == 32 && n == 32 && k == 16 && b == 1)
        return ROCDL::mfma_f32_32x32x16_bf16::getOperationName();
      if (m == 16 && n == 16 && k == 32 && b == 1)
        return ROCDL::mfma_f32_16x16x32_bf16::getOperationName();
    }
    // gfx90a+ "1k" variants double K relative to the original bf16 forms.
    if (chipset >= kGfx90a) {
      if (m == 32 && n == 32 && k == 4 && b == 2)
        return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
      if (m == 16 && n == 16 && k == 4 && b == 4)
        return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
      if (m == 4 && n == 4 && k == 4 && b == 16)
        return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
      if (m == 32 && n == 32 && k == 8 && b == 1)
        return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
      if (m == 16 && n == 16 && k == 16 && b == 1)
        return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
    }
    // Original (gfx908-era) bf16 forms.
    if (m == 32 && n == 32 && k == 2 && b == 2)
      return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
    if (m == 16 && n == 16 && k == 2 && b == 4)
      return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
    if (m == 4 && n == 4 && k == 2 && b == 16)
      return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
    if (m == 32 && n == 32 && k == 4 && b == 1)
      return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
    if (m == 16 && n == 16 && k == 8 && b == 1)
      return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
  }

  // i8 x i8 -> i32 variants.
  if (sourceElem.isInteger(8) && destElem.isInteger(32)) {
    // Larger-K forms introduced on gfx950.
    if (chipset >= kGfx950) {
      if (m == 32 && n == 32 && k == 32 && b == 1)
        return ROCDL::mfma_i32_32x32x32_i8::getOperationName();
      if (m == 16 && n == 16 && k == 64 && b == 1)
        return ROCDL::mfma_i32_16x16x64_i8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_i32_32x32x4i8::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_i32_16x16x4i8::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_i32_4x4x4i8::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_i32_32x32x8i8::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_i32_16x16x16i8::getOperationName();
    if (m == 32 && n == 32 && k == 16 && b == 1 && chipset >= kGfx942)
      return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
    if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >= kGfx942)
      return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
  }

  // f64 x f64 -> f64 variants (gfx90a+ only).
  if (sourceElem.isF64() && destElem.isF64() && chipset >= kGfx90a) {
    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f64_16x16x4f64::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 4)
      return ROCDL::mfma_f64_4x4x4f64::getOperationName();
  }

  // bf8 source A -> f32; source B may independently be bf8 or fp8.
  if (destElem.isF32() && typeIsExpectedBf8ForChipset(chipset, sourceElem)) {
    // Known to be correct because there are no scalar f8 instructions and
    // because a length mismatch will have been caught by the verifier.
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
    }
  }

  // fp8 source A -> f32; source B may independently be bf8 or fp8.
  if (destElem.isF32() && typeIsExpectedFp8ForChipset(chipset, sourceElem)) {
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
    }
  }

  // No matching intrinsic for this shape/type/chipset combination.
  return std::nullopt;
}
1009
/// Map a small-float element type to the numeric format code the scaled-MFMA
/// intrinsics expect in their cbsz/blgp fields: fp8 (E4M3FN) -> 0,
/// bf8 (E5M2) -> 1, fp6 (E2M3FN) -> 2, bf6 (E3M2FN) -> 3, fp4 (E2M1FN) -> 4;
/// std::nullopt for any other type.
static std::optional<uint32_t> smallFloatTypeToFormatCode(Type mlirElemType) {
  // NOTE(review): the line "return TypeSwitch<Type, std::optional<uint32_t>>(
  // mlirElemType)" appears to have been dropped from this rendering of the
  // file -- verify against the upstream source before editing.
      .Case([](Float8E4M3FNType) { return 0u; })
      .Case([](Float8E5M2Type) { return 1u; })
      .Case([](Float6E2M3FNType) { return 2u; })
      .Case([](Float6E3M2FNType) { return 3u; })
      .Case([](Float4E2M1FNType) { return 4u; })
      .Default(std::nullopt);
}
1019
1020/// If there is a scaled MFMA instruction for the input element types `aType`
1021/// and `bType`, output type `destType`, problem size M, N, K, and B (number of
1022/// blocks) on the given `chipset`, return a tuple consisting of the
1023/// OperationName of the intrinsic and the type codes that need to be passed to
1024/// that intrinsic. Note that this is also used to implement some un-scaled
1025/// MFMAs, since the compiler represents the ordinary instruction as a "scaled"
1026/// MFMA with a scale of 0.
1027static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1028mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m,
1029 uint32_t n, uint32_t k, uint32_t b, Chipset chipset) {
1030 aType = getElementTypeOrSelf(aType);
1031 bType = getElementTypeOrSelf(bType);
1032 destType = getElementTypeOrSelf(destType);
1033
1034 if (chipset < kGfx950)
1035 return std::nullopt;
1036 if (!isa<Float32Type>(destType))
1037 return std::nullopt;
1038
1039 std::optional<uint32_t> aTypeCode = smallFloatTypeToFormatCode(aType);
1040 std::optional<uint32_t> bTypeCode = smallFloatTypeToFormatCode(bType);
1041 if (!aTypeCode || !bTypeCode)
1042 return std::nullopt;
1043
1044 if (m == 32 && n == 32 && k == 64 && b == 1)
1045 return std::tuple{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(),
1046 *aTypeCode, *bTypeCode};
1047 if (m == 16 && n == 16 && k == 128 && b == 1)
1048 return std::tuple{
1049 ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), *aTypeCode,
1050 *bTypeCode};
1051
1052 return std::nullopt;
1053}
1054
/// Overload: resolve a (possibly un-scaled) MFMAOp to its scaled intrinsic by
/// forwarding the op's types and problem size to the generic overload above.
static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
mfmaOpToScaledIntrinsic(MFMAOp mfma, Chipset chipset) {
  // NOTE(review): the line "return mfmaOpToScaledIntrinsic(" appears to have
  // been dropped from this rendering of the file -- verify against upstream.
      mfma.getSourceA().getType(), mfma.getSourceB().getType(),
      mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(),
      mfma.getBlocks(), chipset);
}
1062
1063static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1064mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
1065 return mfmaOpToScaledIntrinsic(smfma.getSourceA().getType(),
1066 smfma.getSourceB().getType(),
1067 smfma.getDestC().getType(), smfma.getM(),
1068 smfma.getN(), smfma.getK(), 1u, chipset);
1069}
1070
/// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
/// for RDNA3/4 architectures.
///
/// `elemSourceType`/`elemBSourceType` are the element types of sources A/B,
/// `elemDestType` the accumulator element type, `k` the reduction dimension,
/// and `isRDNA3` selects the RDNA3 (gfx11) subset vs. RDNA4 (gfx120x).
/// Returns std::nullopt when no intrinsic matches.
static std::optional<StringRef>
wmmaOpToIntrinsicRDNA(Type elemSourceType, Type elemBSourceType,
                      Type elemDestType, uint32_t k, bool isRDNA3) {
  using fp8 = Float8E4M3FNType;
  using bf8 = Float8E5M2Type;

  // Handle k == 16 for RDNA3/4.
  if (k == 16) {
    // Common patterns for RDNA3 and RDNA4.
    if (elemSourceType.isF16() && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
    if (elemSourceType.isBF16() && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
    if (elemSourceType.isF16() && elemDestType.isF16())
      return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
    if (elemSourceType.isBF16() && elemDestType.isBF16())
      return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
    if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
      return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();

    // RDNA3 specific patterns.
    if (isRDNA3) {
      if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
        return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
      // RDNA3 has no fp8/bf8 WMMA; everything else is unsupported.
      return std::nullopt;
    }

    // RDNA4 specific patterns (fp8/bf8).
    if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
        elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
    if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
        elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
    if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
        elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
    if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
        elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
      return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();

    return std::nullopt;
  }

  // Handle k == 32 for RDNA4.
  if (k == 32 && !isRDNA3) {
    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
      return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
  }

  return std::nullopt;
}
1127
/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
/// for the gfx1250 architecture.
///
/// Dispatch is keyed on the reduction size `k`; returns std::nullopt when no
/// gfx1250 intrinsic matches the element-type combination.
static std::optional<StringRef> wmmaOpToIntrinsicGfx1250(Type elemSourceType,
                                                         Type elemBSourceType,
                                                         Type elemDestType,
                                                         uint32_t k) {
  using fp8 = Float8E4M3FNType;
  using bf8 = Float8E5M2Type;

  // k == 4: only the f32 x f32 -> f32 form exists.
  if (k == 4) {
    if (elemSourceType.isF32() && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x4_f32::getOperationName();

    return std::nullopt;
  }

  // k == 32: 16-bit float inputs, f32 or same-type accumulators.
  if (k == 32) {
    if (elemSourceType.isF16() && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x32_f16::getOperationName();
    if (elemSourceType.isBF16() && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x32_bf16::getOperationName();
    if (elemSourceType.isF16() && elemDestType.isF16())
      return ROCDL::wmma_f16_16x16x32_f16::getOperationName();
    if (elemSourceType.isBF16() && elemDestType.isBF16())
      return ROCDL::wmma_bf16_16x16x32_bf16::getOperationName();

    return std::nullopt;
  }

  // k == 64: all fp8/bf8 A/B combinations (f32 or f16 accumulator) plus iu8.
  if (k == 64) {
    if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x64_fp8_fp8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x64_fp8_fp8::getOperationName();
    }
    if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x64_fp8_bf8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x64_fp8_bf8::getOperationName();
    }
    if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x64_bf8_bf8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x64_bf8_bf8::getOperationName();
    }
    if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x64_bf8_fp8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x64_bf8_fp8::getOperationName();
    }
    if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
      return ROCDL::wmma_i32_16x16x64_iu8::getOperationName();

    return std::nullopt;
  }

  // k == 128: fp8/bf8 A/B combinations only (f32 or f16 accumulator).
  if (k == 128) {
    if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x128_fp8_fp8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x128_fp8_fp8::getOperationName();
    }
    if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x128_fp8_bf8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x128_fp8_bf8::getOperationName();
    }
    if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x128_bf8_bf8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x128_bf8_bf8::getOperationName();
    }
    if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x128_bf8_fp8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x128_bf8_fp8::getOperationName();
    }

    return std::nullopt;
  }

  return std::nullopt;
}
1219
/// Returns the `rocdl` intrinsic corresponding to a SparseMFMA (smfmac)
/// operation if one exists. This includes checking to ensure the intrinsic is
/// supported on the architecture you are compiling for.
static std::optional<StringRef> smfmacOpToIntrinsic(SparseMFMAOp op,
                                                    Chipset chipset) {
  // gfx950 adds the larger-K f16/bf16 variants and the K-doubled i8/f8 forms.
  bool isGfx950 = chipset >= kGfx950;
  // Chipset-aware f8 classification (FNUZ on gfx942, OCP encodings otherwise).
  auto isFp8 = [&](Type t) { return typeIsExpectedFp8ForChipset(chipset, t); };
  auto isBf8 = [&](Type t) { return typeIsExpectedBf8ForChipset(chipset, t); };

  uint32_t m = op.getM(), n = op.getN(), k = op.getK();
  Type sourceAElem = getElementTypeOrSelf(op.getSourceA().getType());
  Type sourceBElem = getElementTypeOrSelf(op.getSourceB().getType());
  Type destElem = getElementTypeOrSelf(op.getDestC().getType());

  // 16x16x32: f16/bf16 inputs only.
  if (m == 16 && n == 16 && k == 32) {
    if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32())
      return ROCDL::smfmac_f32_16x16x32_f16::getOperationName();
    if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32())
      return ROCDL::smfmac_f32_16x16x32_bf16::getOperationName();
  }

  // 16x16x64: f16/bf16 require gfx950; i8 and f8 combos are available more
  // widely.
  if (m == 16 && n == 16 && k == 64) {
    if (isGfx950) {
      if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32())
        return ROCDL::smfmac_f32_16x16x64_f16::getOperationName();
      if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32())
        return ROCDL::smfmac_f32_16x16x64_bf16::getOperationName();
    }
    if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
        destElem.isInteger(32))
      return ROCDL::smfmac_i32_16x16x64_i8::getOperationName();
    if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_16x16x64_fp8_fp8::getOperationName();
    if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_16x16x64_fp8_bf8::getOperationName();
    if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_16x16x64_bf8_fp8::getOperationName();
    if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_16x16x64_bf8_bf8::getOperationName();
  }

  // 16x16x128: gfx950-only i8/f8 variants.
  if (m == 16 && n == 16 && k == 128 && isGfx950) {
    if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
        destElem.isInteger(32))
      return ROCDL::smfmac_i32_16x16x128_i8::getOperationName();
    if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_16x16x128_fp8_fp8::getOperationName();
    if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_16x16x128_fp8_bf8::getOperationName();
    if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_16x16x128_bf8_fp8::getOperationName();
    if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_16x16x128_bf8_bf8::getOperationName();
  }

  // 32x32x16: f16/bf16 inputs only.
  if (m == 32 && n == 32 && k == 16) {
    if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32())
      return ROCDL::smfmac_f32_32x32x16_f16::getOperationName();
    if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32())
      return ROCDL::smfmac_f32_32x32x16_bf16::getOperationName();
  }

  // 32x32x32: f16/bf16 require gfx950; i8/f8 combos otherwise.
  if (m == 32 && n == 32 && k == 32) {
    if (isGfx950) {
      if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32())
        return ROCDL::smfmac_f32_32x32x32_f16::getOperationName();
      if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32())
        return ROCDL::smfmac_f32_32x32x32_bf16::getOperationName();
    }
    if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
        destElem.isInteger(32))
      return ROCDL::smfmac_i32_32x32x32_i8::getOperationName();
    if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_32x32x32_fp8_fp8::getOperationName();
    if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_32x32x32_fp8_bf8::getOperationName();
    if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_32x32x32_bf8_fp8::getOperationName();
    if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_32x32x32_bf8_bf8::getOperationName();
  }

  // 32x32x64: gfx950-only i8/f8 variants.
  if (m == 32 && n == 32 && k == 64 && isGfx950) {
    if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
        destElem.isInteger(32))
      return ROCDL::smfmac_i32_32x32x64_i8::getOperationName();
    if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_32x32x64_fp8_fp8::getOperationName();
    if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_32x32x64_fp8_bf8::getOperationName();
    if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_32x32x64_bf8_fp8::getOperationName();
    if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_32x32x64_bf8_bf8::getOperationName();
  }

  return std::nullopt;
}
1318
1319/// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
1320/// if one exists. This includes checking to ensure the intrinsic is supported
1321/// on the architecture you are compiling for.
1322static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
1323 Chipset chipset) {
1324 auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
1325 auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
1326 auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
1327 Type elemSourceType = sourceVectorType.getElementType();
1328 Type elemBSourceType = sourceBVectorType.getElementType();
1329 Type elemDestType = destVectorType.getElementType();
1330
1331 const uint32_t k = wmma.getK();
1332 const bool isRDNA3 = chipset.majorVersion == 11;
1333 const bool isRDNA4 = chipset.majorVersion == 12 && chipset.minorVersion == 0;
1334
1335 // Handle RDNA3 and RDNA4.
1336 if (isRDNA3 || isRDNA4)
1337 return wmmaOpToIntrinsicRDNA(elemSourceType, elemBSourceType, elemDestType,
1338 k, isRDNA3);
1339
1340 // Handle gfx1250.
1341 if (chipset == kGfx1250)
1342 return wmmaOpToIntrinsicGfx1250(elemSourceType, elemBSourceType,
1343 elemDestType, k);
1344
1345 return std::nullopt;
1346}
1347
/// Returns the `rocdl` intrinsic corresponding to a SparseWMMA operation
/// `swmmac` if one exists. This includes checking to ensure the intrinsic is
/// supported on the architecture you are compiling for.
// NOTE(review): this rendering appears to have dropped the
// "struct SparseWMMAOpInfo {" header line and three boolean flag fields
// between `name` and the closing brace (sparseWMMAOpToIntrinsic below
// brace-initializes the struct with a name plus three bools) -- verify the
// full definition against the upstream source.
  StringRef name;
};
1357
1358static std::optional<SparseWMMAOpInfo>
1359sparseWMMAOpToIntrinsic(SparseWMMAOp swmmac, Chipset chipset) {
1360 Type sourceAElem = getElementTypeOrSelf(swmmac.getSourceA().getType());
1361 Type sourceBElem = getElementTypeOrSelf(swmmac.getSourceB().getType());
1362 Type destElem = getElementTypeOrSelf(swmmac.getDestC().getType());
1363
1364 uint32_t m = swmmac.getM(), n = swmmac.getN(), k = swmmac.getK();
1365
1366 if ((m != 16) || (n != 16))
1367 return std::nullopt;
1368
1369 const bool isRDNA4 = chipset.majorVersion == 12 && chipset.minorVersion == 0;
1370 if (isRDNA4) {
1371 if (k == 32) {
1372 if (destElem.isF32() && sourceAElem.isF16() && sourceBElem.isF16())
1373 return SparseWMMAOpInfo{
1374 ROCDL::swmmac_f32_16x16x32_f16::getOperationName(), false, false,
1375 false};
1376 if (destElem.isF32() && sourceAElem.isBF16() && sourceBElem.isBF16())
1377 return SparseWMMAOpInfo{
1378 ROCDL::swmmac_f32_16x16x32_bf16::getOperationName(), false, false,
1379 false};
1380 if (destElem.isF16() && sourceAElem.isF16() && sourceBElem.isF16())
1381 return SparseWMMAOpInfo{
1382 ROCDL::swmmac_f16_16x16x32_f16::getOperationName(), false, false,
1383 false};
1384 if (destElem.isBF16() && sourceAElem.isBF16() && sourceBElem.isBF16())
1385 return SparseWMMAOpInfo{
1386 ROCDL::swmmac_bf16_16x16x32_bf16::getOperationName(), false, false,
1387 false};
1388 if (destElem.isInteger(32) && sourceAElem.isInteger(8) &&
1389 sourceBElem.isInteger(8))
1390 return SparseWMMAOpInfo{
1391 ROCDL::swmmac_i32_16x16x32_iu8::getOperationName(), true, false,
1392 true};
1393 if (destElem.isInteger(32) && sourceAElem.isInteger(4) &&
1394 sourceBElem.isInteger(4))
1395 return SparseWMMAOpInfo{
1396 ROCDL::swmmac_i32_16x16x32_iu4::getOperationName(), true, false,
1397 true};
1398 if (destElem.isF32() && sourceAElem.isF8E4M3FN() &&
1399 sourceBElem.isF8E4M3FN())
1400 return SparseWMMAOpInfo{
1401 ROCDL::swmmac_f32_16x16x32_fp8_fp8::getOperationName(), false,
1402 false, false};
1403 if (destElem.isF32() && sourceAElem.isF8E4M3FN() &&
1404 sourceBElem.isF8E5M2())
1405 return SparseWMMAOpInfo{
1406 ROCDL::swmmac_f32_16x16x32_fp8_bf8::getOperationName(), false,
1407 false, false};
1408 if (destElem.isF32() && sourceAElem.isF8E5M2() &&
1409 sourceBElem.isF8E4M3FN())
1410 return SparseWMMAOpInfo{
1411 ROCDL::swmmac_f32_16x16x32_bf8_fp8::getOperationName(), false,
1412 false, false};
1413 if (destElem.isF32() && sourceAElem.isF8E5M2() && sourceBElem.isF8E5M2())
1414 return SparseWMMAOpInfo{
1415 ROCDL::swmmac_f32_16x16x32_bf8_bf8::getOperationName(), false,
1416 false, false};
1417 }
1418 if (k == 64) {
1419 if (destElem.isInteger(32) && sourceAElem.isInteger(4) &&
1420 sourceBElem.isInteger(4))
1421 return SparseWMMAOpInfo{
1422 ROCDL::swmmac_i32_16x16x64_iu4::getOperationName(), true, false,
1423 true};
1424 }
1425 }
1426
1427 const bool isGFX1250 = chipset == kGfx1250;
1428 const bool isWavesize64 = swmmac.getWave64();
1429 if (isGFX1250 && !isWavesize64) {
1430 if (k == 64) {
1431 if (destElem.isF32() && sourceAElem.isF16() && sourceBElem.isF16())
1432 return SparseWMMAOpInfo{
1433 ROCDL::swmmac_f32_16x16x64_f16::getOperationName(), true, true,
1434 false};
1435 if (destElem.isF32() && sourceAElem.isBF16() && sourceBElem.isBF16())
1436 return SparseWMMAOpInfo{
1437 ROCDL::swmmac_f32_16x16x64_bf16::getOperationName(), true, true,
1438 false};
1439 if (destElem.isF16() && sourceAElem.isF16() && sourceBElem.isF16())
1440 return SparseWMMAOpInfo{
1441 ROCDL::swmmac_f16_16x16x64_f16::getOperationName(), true, true,
1442 false};
1443 if (destElem.isBF16() && sourceAElem.isBF16() && sourceBElem.isBF16())
1444 return SparseWMMAOpInfo{
1445 ROCDL::swmmac_bf16_16x16x64_bf16::getOperationName(), true, true,
1446 false};
1447 }
1448 if (k == 128) {
1449 if (destElem.isF32() && sourceAElem.isF8E4M3FN() &&
1450 sourceBElem.isF8E4M3FN())
1451 return SparseWMMAOpInfo{
1452 ROCDL::swmmac_f32_16x16x128_fp8_fp8::getOperationName(), false,
1453 true, false};
1454 if (destElem.isF32() && sourceAElem.isF8E4M3FN() &&
1455 sourceBElem.isF8E5M2())
1456 return SparseWMMAOpInfo{
1457 ROCDL::swmmac_f32_16x16x128_fp8_bf8::getOperationName(), false,
1458 true, false};
1459 if (destElem.isF32() && sourceAElem.isF8E5M2() &&
1460 sourceBElem.isF8E4M3FN())
1461 return SparseWMMAOpInfo{
1462 ROCDL::swmmac_f32_16x16x128_bf8_fp8::getOperationName(), false,
1463 true, false};
1464 if (destElem.isF32() && sourceAElem.isF8E5M2() && sourceBElem.isF8E5M2())
1465 return SparseWMMAOpInfo{
1466 ROCDL::swmmac_f32_16x16x128_bf8_bf8::getOperationName(), false,
1467 true, false};
1468 if (destElem.isF16() && sourceAElem.isF8E4M3FN() &&
1469 sourceBElem.isF8E4M3FN())
1470 return SparseWMMAOpInfo{
1471 ROCDL::swmmac_f16_16x16x128_fp8_fp8::getOperationName(), false,
1472 true, false};
1473 if (destElem.isF16() && sourceAElem.isF8E4M3FN() &&
1474 sourceBElem.isF8E5M2())
1475 return SparseWMMAOpInfo{
1476 ROCDL::swmmac_f16_16x16x128_fp8_bf8::getOperationName(), false,
1477 true, false};
1478 if (destElem.isF16() && sourceAElem.isF8E5M2() &&
1479 sourceBElem.isF8E4M3FN())
1480 return SparseWMMAOpInfo{
1481 ROCDL::swmmac_f16_16x16x128_bf8_fp8::getOperationName(), false,
1482 true, false};
1483 if (destElem.isF16() && sourceAElem.isF8E5M2() && sourceBElem.isF8E5M2())
1484 return SparseWMMAOpInfo{
1485 ROCDL::swmmac_f16_16x16x128_bf8_bf8::getOperationName(), false,
1486 true, false};
1487 if (destElem.isF16() && sourceAElem.isInteger(8) &&
1488 sourceBElem.isInteger(8))
1489 return SparseWMMAOpInfo{
1490 ROCDL::swmmac_f16_16x16x128_bf8_bf8::getOperationName(), false,
1491 true, false};
1492 if (destElem.isInteger(32) && sourceAElem.isInteger(8) &&
1493 sourceBElem.isInteger(8))
1494 return SparseWMMAOpInfo{
1495 ROCDL::swmmac_i32_16x16x128_iu8::getOperationName(), true, true,
1496 true};
1497 }
1498 }
1499
1500 return std::nullopt;
1501}
1502
1503namespace {
/// Lower amdgpu.mfma to the matching rocdl.mfma.* (or rocdl.mfma.scale.*)
/// intrinsic op for the configured chipset.
struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
  MFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {}

  // Target chipset; drives intrinsic selection and feature gating.
  Chipset chipset;

  LogicalResult
  matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type outType = typeConverter->convertType(op.getDestD().getType());
    Type intrinsicOutType = outType;
    // The intrinsics model bf16 results as vectors of i16; bitcast back at
    // the end if we substituted the type here.
    if (auto outVecType = dyn_cast<VectorType>(outType))
      if (outVecType.getElementType().isBF16())
        intrinsicOutType = outVecType.clone(rewriter.getI16Type());

    // MFMA exists only on CDNA-family (gfx9xx) chips starting at gfx908.
    if (chipset.majorVersion != 9 || chipset < kGfx908)
      return op->emitOpError("MFMA only supported on gfx908+")
    uint32_t getBlgpField = static_cast<uint32_t>(op.getBlgp());
    if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
      if (chipset < kGfx942)
        return op.emitOpError("negation unsupported on older than gfx942");
      // On gfx942+ the negate bits are encoded in the low bits of blgp.
      getBlgpField |=
          op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
    }
    std::optional<StringRef> maybeIntrinsic = mfmaOpToIntrinsic(op, chipset);
    std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
        maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
    if (!maybeIntrinsic.has_value() && !maybeScaledIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching MFMA size on given chipset");

    // Fall back to the scaled form (with scale = 0) only when no ordinary
    // intrinsic exists for this shape.
    bool isScaled =
        !maybeIntrinsic.has_value() && maybeScaledIntrinsic.has_value();
    if (isScaled &&
        (adaptor.getAbid() > 0 || getBlgpField > 0 || op.getCbsz() > 0)) {
      return op.emitOpError(
          "non-default abid, blgp, and cbsz aren't supported on MFMAs that can "
          "be scaled as those fields are used for type information");
    }

    StringRef intrinsicName =
        isScaled ? std::get<0>(*maybeScaledIntrinsic) : *maybeIntrinsic;
    // Determine if we can use bf16 in the intrinsic. Newer MFMAs in gfx950+
    // allows bf16 as the input. For reference check IntrinsicsAMDGPU.td file.
    bool allowBf16 = [&]() {
      if (chipset < kGfx950)
        return false;
      if (isScaled)
        return true;
      return intrinsicName.contains("16x16x32.bf16") ||
             intrinsicName.contains("32x32x16.bf16");
    }();
    // Build the target intrinsic op generically from its name.
    OperationState loweredOp(loc, intrinsicName);
    loweredOp.addTypes(intrinsicOutType);
    loweredOp.addOperands({packSmallFloatVectorOperand(
                               rewriter, loc, adaptor.getSourceA(), allowBf16),
                           // NOTE(review): a "packSmallFloatVectorOperand("
                           // line for source B appears to have been dropped
                           // from this rendering -- verify against upstream.
                               rewriter, loc, adaptor.getSourceB(), allowBf16),
                           adaptor.getDestC()});
    if (isScaled) {
      // Un-scaled ops lowered through the scaled intrinsic use a scale of 0;
      // cbsz/blgp carry the operand format codes instead of their usual
      // meaning.
      Value zero = createI32Constant(rewriter, loc, 0);
      auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
      loweredOp.addOperands({/*scale A=*/zero, /*scale B=*/zero});
      loweredOp.addAttributes({{"cbsz", rewriter.getI32IntegerAttr(aTypeCode)},
                               {"blgp", rewriter.getI32IntegerAttr(bTypeCode)},
                               {"opselA", rewriter.getI32IntegerAttr(0)},
                               {"opselB", rewriter.getI32IntegerAttr(0)}});
    } else {
      loweredOp.addAttributes(
          {{"cbsz", rewriter.getI32IntegerAttr(op.getCbsz())},
           {"abid", rewriter.getI32IntegerAttr(op.getAbid())},
           {"blgp", rewriter.getI32IntegerAttr(getBlgpField)}});
    };
    Value lowered = rewriter.create(loweredOp)->getResult(0);
    // Undo the bf16 -> i16 substitution performed above, if any.
    if (outType != intrinsicOutType)
      lowered = LLVM::BitcastOp::create(rewriter, loc, outType, lowered);
    rewriter.replaceOp(op, lowered);
    return success();
  }
};
1584
1585struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
1586 ScaledMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1587 : ConvertOpToLLVMPattern(converter), chipset(chipset) {}
1588
1589 Chipset chipset;
1590
1591 LogicalResult
1592 matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor,
1593 ConversionPatternRewriter &rewriter) const override {
1594 Location loc = op.getLoc();
1595 Type intrinsicOutType = typeConverter->convertType(op.getDestD().getType());
1596
1597 if (chipset.majorVersion != 9 || chipset < kGfx950)
1598 return op->emitOpError("scaled MFMA only supported on gfx908+");
1599 std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1600 maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
1601 if (!maybeScaledIntrinsic.has_value())
1602 return op.emitOpError(
1603 "no intrinsic matching scaled MFMA size on given chipset");
1604
1605 auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
1606 OperationState loweredOp(loc, intrinsicName);
1607 loweredOp.addTypes(intrinsicOutType);
1608 loweredOp.addOperands(
1609 {packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceA()),
1610 packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceB()),
1611 adaptor.getDestC()});
1612 loweredOp.addOperands(
1613 {/*scales A*/
1614 castScaleOperand(rewriter, loc, adaptor.getScalesA()),
1615 /*scales B*/
1616 castScaleOperand(rewriter, loc, adaptor.getScalesB())});
1617 loweredOp.addAttributes(
1618 {{"cbsz", rewriter.getI32IntegerAttr(aTypeCode)},
1619 {"blgp", rewriter.getI32IntegerAttr(bTypeCode)},
1620 {"opselA", rewriter.getI32IntegerAttr(adaptor.getScalesIdxA())},
1621 {"opselB", rewriter.getI32IntegerAttr(adaptor.getScalesIdxB())}});
1622
1623 Value lowered = rewriter.create(loweredOp)->getResult(0);
1624 rewriter.replaceOp(op, lowered);
1625 return success();
1626 }
1627};
1628
1629struct SparseMFMAOpLowering : public ConvertOpToLLVMPattern<SparseMFMAOp> {
1630 SparseMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1631 : ConvertOpToLLVMPattern<SparseMFMAOp>(converter), chipset(chipset) {}
1632
1633 Chipset chipset;
1634
1635 LogicalResult
1636 matchAndRewrite(SparseMFMAOp op, SparseMFMAOpAdaptor adaptor,
1637 ConversionPatternRewriter &rewriter) const override {
1638 Location loc = op.getLoc();
1639 auto outType =
1640 typeConverter->convertType<VectorType>(op.getDestC().getType());
1641 if (!outType)
1642 return rewriter.notifyMatchFailure(op, "type conversion failed");
1643
1644 // smfmac is supported on gfx942 and gfx950.
1645 if (chipset.majorVersion != 9 || chipset < kGfx942)
1646 return op->emitOpError("sparse MFMA (smfmac) only supported on gfx942+");
1647 bool isGfx950 = chipset >= kGfx950;
1648
1649 Value a = convertSparseVectorOperand(rewriter, loc, adaptor.getSourceA(),
1650 isGfx950);
1651 Value b = convertSparseVectorOperand(rewriter, loc, adaptor.getSourceB(),
1652 isGfx950);
1653 Value c = adaptor.getDestC();
1654
1655 std::optional<StringRef> maybeIntrinsic = smfmacOpToIntrinsic(op, chipset);
1656 if (!maybeIntrinsic.has_value())
1657 return op.emitOpError(
1658 "no intrinsic matching sparse MFMA on the given chipset");
1659
1660 // Bitcast sparse indices from vector<4xi8> or vector<2xi16> to i32.
1661 Value sparseIdx = LLVM::BitcastOp::create(
1662 rewriter, loc, rewriter.getI32Type(), adaptor.getSparseIdx());
1663
1664 OperationState loweredOp(loc, maybeIntrinsic.value());
1665 loweredOp.addTypes(outType);
1666 loweredOp.addOperands({a, b, c, sparseIdx});
1667 loweredOp.addAttributes(
1668 {{"cbsz", rewriter.getI32IntegerAttr(op.getCbsz())},
1669 {"abid", rewriter.getI32IntegerAttr(op.getAbid())}});
1670 Value lowered = rewriter.create(loweredOp)->getResult(0);
1671 rewriter.replaceOp(op, lowered);
1672 return success();
1673 }
1674};
1675
1676struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
1677 WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1678 : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}
1679
1680 Chipset chipset;
1681
1682 LogicalResult
1683 matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
1684 ConversionPatternRewriter &rewriter) const override {
1685 Location loc = op.getLoc();
1686 auto outType =
1687 typeConverter->convertType<VectorType>(op.getDestD().getType());
1688 if (!outType)
1689 return rewriter.notifyMatchFailure(op, "type conversion failed");
1690
1691 if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
1692 return op->emitOpError("WMMA only supported on gfx11 and gfx12");
1693
1694 bool isGFX1250 = chipset >= kGfx1250;
1695
1696 // The WMMA operations represent vectors of bf16s as vectors of i16s
1697 // (except on gfx1250), so we need to bitcast bfloats to i16 and then
1698 // bitcast them back.
1699 auto aType = cast<VectorType>(adaptor.getSourceA().getType());
1700 auto bType = cast<VectorType>(adaptor.getSourceB().getType());
1701 auto destCType = cast<VectorType>(adaptor.getDestC().getType());
1702 bool castAToI16 = aType.getElementType().isBF16() && !isGFX1250;
1703 bool castBToI16 = bType.getElementType().isBF16() && !isGFX1250;
1704 bool castDestCToI16 = destCType.getElementType().isBF16() && !isGFX1250;
1705 bool castOutToI16 = outType.getElementType().isBF16() && !isGFX1250;
1706 VectorType rawOutType = outType;
1707 if (castOutToI16)
1708 rawOutType = outType.clone(rewriter.getI16Type());
1709 Value a = adaptor.getSourceA();
1710 if (castAToI16)
1711 a = LLVM::BitcastOp::create(rewriter, loc,
1712 aType.clone(rewriter.getI16Type()), a);
1713 Value b = adaptor.getSourceB();
1714 if (castBToI16)
1715 b = LLVM::BitcastOp::create(rewriter, loc,
1716 bType.clone(rewriter.getI16Type()), b);
1717 Value destC = adaptor.getDestC();
1718 if (castDestCToI16)
1719 destC = LLVM::BitcastOp::create(
1720 rewriter, loc, destCType.clone(rewriter.getI16Type()), destC);
1721
1722 std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);
1723
1724 if (!maybeIntrinsic.has_value())
1725 return op.emitOpError("no intrinsic matching WMMA on the given chipset");
1726
1727 if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0)
1728 return op.emitOpError("subwordOffset not supported on gfx12+");
1729
1730 SmallVector<Value, 4> operands;
1731 SmallVector<NamedAttribute, 4> attrs;
1732 wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(), a,
1733 op.getSourceA(), operands, attrs, "signA");
1734 wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(), b,
1735 op.getSourceB(), operands, attrs, "signB");
1736 wmmaPushOutputOperand(rewriter, loc, typeConverter, destC,
1737 op.getSubwordOffset(), op.getClamp(), operands,
1738 attrs);
1739
1740 OperationState loweredOp(loc, *maybeIntrinsic);
1741 loweredOp.addTypes(rawOutType);
1742 loweredOp.addOperands(operands);
1743 loweredOp.addAttributes(attrs);
1744 Operation *lowered = rewriter.create(loweredOp);
1745
1746 Operation *maybeCastBack = lowered;
1747 if (rawOutType != outType)
1748 maybeCastBack = LLVM::BitcastOp::create(rewriter, loc, outType,
1749 lowered->getResult(0));
1750 rewriter.replaceOp(op, maybeCastBack->getResults());
1751
1752 return success();
1753 }
1754};
1755
1756struct SparseWMMAOpLowering : public ConvertOpToLLVMPattern<SparseWMMAOp> {
1757 SparseWMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1758 : ConvertOpToLLVMPattern<SparseWMMAOp>(converter), chipset(chipset) {}
1759
1760 Chipset chipset;
1761
1762 LogicalResult
1763 matchAndRewrite(SparseWMMAOp op, SparseWMMAOpAdaptor adaptor,
1764 ConversionPatternRewriter &rewriter) const override {
1765 Location loc = op.getLoc();
1766 auto outType =
1767 typeConverter->convertType<VectorType>(op.getDestD().getType());
1768 if (!outType)
1769 return rewriter.notifyMatchFailure(op, "type conversion failed");
1770
1771 std::optional<SparseWMMAOpInfo> maybeIntrinsic =
1772 sparseWMMAOpToIntrinsic(op, chipset);
1773
1774 if (!maybeIntrinsic.has_value())
1775 return op.emitOpError(
1776 "no intrinsic matching Sparse WMMA on the given chipset");
1777 SparseWMMAOpInfo intrinsic = maybeIntrinsic.value();
1778
1779 SmallVector<NamedAttribute> attrs;
1780
1781 if ((op.getUnsignedA() || op.getUnsignedB()) && !intrinsic.useSign)
1782 return op->emitOpError("intrinsic doesn't support unsign");
1783 if (intrinsic.useSign) {
1784 if (auto attr = op.getUnsignedAAttr())
1785 attrs.push_back({"signA", attr});
1786 if (auto attr = op.getUnsignedBAttr())
1787 attrs.push_back({"signB", attr});
1788 }
1789
1790 if ((op.getReuseA() || op.getReuseB()) && !intrinsic.useReuse)
1791 return op->emitOpError("intrinsic doesn't support reuse");
1792 if (intrinsic.useReuse) {
1793 if (auto attr = op.getReuseAAttr())
1794 attrs.push_back({"reuseA", attr});
1795 if (auto attr = op.getReuseBAttr())
1796 attrs.push_back({"reuseB", attr});
1797 }
1798
1799 if (op.getClamp() && !intrinsic.useClamp)
1800 return op->emitOpError("intrinsic doesn't support clamp");
1801 if (intrinsic.useClamp && op.getClampAttr())
1802 attrs.push_back({"clamp", op.getClampAttr()});
1803
1804 const bool isGFX1250orHigher =
1805 chipset.majorVersion == 12 && chipset.minorVersion >= 5;
1806 Value a = convertSparseVectorOperand(rewriter, loc, adaptor.getSourceA(),
1807 isGFX1250orHigher);
1808 Value b = convertSparseVectorOperand(rewriter, loc, adaptor.getSourceB(),
1809 isGFX1250orHigher);
1810 Value c = adaptor.getDestC();
1811 VectorType rawOutType = outType;
1812 if (!isGFX1250orHigher) {
1813 c = convertSparseVectorOperand(rewriter, loc, adaptor.getDestC(), false);
1814 rawOutType = cast<VectorType>(c.getType());
1815 }
1816
1817 // Bitcast sparse indices from vector<4xi8> to i32.
1818 Value sparseIdx = LLVM::BitcastOp::create(
1819 rewriter, loc, rewriter.getI32Type(), adaptor.getSparseIdx());
1820
1821 OperationState loweredOp(loc, intrinsic.name);
1822 loweredOp.addTypes(rawOutType);
1823 loweredOp.addOperands({a, b, c, sparseIdx});
1824 loweredOp.addAttributes(attrs);
1825 Operation *lowered = rewriter.create(loweredOp);
1826
1827 Operation *maybeCastBack = lowered;
1828 if (rawOutType != outType)
1829 maybeCastBack = LLVM::BitcastOp::create(rewriter, loc, outType,
1830 lowered->getResult(0));
1831 rewriter.replaceOp(op, maybeCastBack->getResults());
1832
1833 return success();
1834 }
1835};
1836
1837struct ScaledWMMAOpLowering : public ConvertOpToLLVMPattern<ScaledWMMAOp> {
1838 ScaledWMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1839 : ConvertOpToLLVMPattern<ScaledWMMAOp>(converter), chipset(chipset) {}
1840
1841 Chipset chipset;
1842
1843 LogicalResult
1844 matchAndRewrite(ScaledWMMAOp op, ScaledWMMAOpAdaptor adaptor,
1845 ConversionPatternRewriter &rewriter) const override {
1846 Location loc = op.getLoc();
1847 auto outType =
1848 typeConverter->convertType<VectorType>(op.getDestD().getType());
1849 if (!outType)
1850 return rewriter.notifyMatchFailure(op, "type conversion failed");
1851
1852 if (chipset < kGfx1250)
1853 return op->emitOpError("WMMA scale only supported on gfx1250+");
1854
1855 int64_t m = op.getM();
1856 int64_t n = op.getN();
1857 int64_t k = op.getK();
1858
1859 Type aElemType = getElementTypeOrSelf(op.getSourceA().getType());
1860 Type bElemType = getElementTypeOrSelf(op.getSourceB().getType());
1861
1862 std::optional<uint32_t> aFmtCode = smallFloatTypeToFormatCode(aElemType);
1863 std::optional<uint32_t> bFmtCode = smallFloatTypeToFormatCode(bElemType);
1864
1865 if (!aFmtCode || !bFmtCode)
1866 return op.emitOpError("unsupported element types for scaled_wmma");
1867
1868 // Get scale vector types and determine variant (scale vs scale16).
1869 auto scaleAVecType = cast<VectorType>(op.getScaleA().getType());
1870 auto scaleBVecType = cast<VectorType>(op.getScaleB().getType());
1871
1872 if (scaleAVecType.getNumElements() != scaleBVecType.getNumElements())
1873 return op.emitOpError("scaleA and scaleB must have equal vector length");
1874
1875 // Extract scale format from element types.
1876 Type scaleAElemType = scaleAVecType.getElementType();
1877 Type scaleBElemType = scaleBVecType.getElementType();
1878
1879 std::optional<uint32_t> scaleAFmt = getWmmaScaleFormat(scaleAElemType);
1880 std::optional<uint32_t> scaleBFmt = getWmmaScaleFormat(scaleBElemType);
1881
1882 if (!scaleAFmt || !scaleBFmt)
1883 return op.emitOpError("unsupported scale element types");
1884
1885 // Determine which intrinsic to use based on dimensions.
1886 bool isScale16 = (scaleAVecType.getNumElements() == 8);
1887 std::optional<StringRef> intrinsicName =
1888 getScaledWmmaIntrinsicName(m, n, k, isScale16);
1889 if (!intrinsicName)
1890 return op.emitOpError("unsupported scaled_wmma dimensions: ")
1891 << m << "x" << n << "x" << k;
1892
1893 SmallVector<NamedAttribute, 8> attrs;
1894
1895 // The f4 variant does not have fmtA and fmtB attributes.
1896 bool is32x16 = (m == 32 && n == 16 && k == 128);
1897 if (!is32x16) {
1898 attrs.emplace_back("fmtA", rewriter.getI32IntegerAttr(*aFmtCode));
1899 attrs.emplace_back("fmtB", rewriter.getI32IntegerAttr(*bFmtCode));
1900 }
1901
1902 // modC uses default value of 0.
1903 attrs.emplace_back("modC", rewriter.getI16IntegerAttr(0));
1904
1905 // Scale attributes. Convert user-facing firstScaleLane (0 or 16) to the
1906 // half of the wave that is being selected (0 or 1).
1907 attrs.emplace_back(
1908 "scaleAType", rewriter.getI32IntegerAttr(op.getAFirstScaleLane() / 16));
1909 attrs.emplace_back("fmtScaleA", rewriter.getI32IntegerAttr(*scaleAFmt));
1910 attrs.emplace_back(
1911 "scaleBType", rewriter.getI32IntegerAttr(op.getBFirstScaleLane() / 16));
1912 attrs.emplace_back("fmtScaleB", rewriter.getI32IntegerAttr(*scaleBFmt));
1913
1914 // Reuse flags use default value of false.
1915 attrs.emplace_back("reuseA", rewriter.getBoolAttr(false));
1916 attrs.emplace_back("reuseB", rewriter.getBoolAttr(false));
1917
1918 // Convert typed float vectors to packed format.
1919 Value sourceA =
1920 packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceA());
1921 Value sourceB =
1922 packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceB());
1923
1924 // Pack scale vectors into i32/i64.
1925 Value packedScaleA = castScaleOperand(rewriter, loc, adaptor.getScaleA());
1926 Value packedScaleB = castScaleOperand(rewriter, loc, adaptor.getScaleB());
1927
1928 // Create the intrinsic call.
1929 OperationState loweredOp(loc, *intrinsicName);
1930 loweredOp.addTypes(outType);
1931 loweredOp.addOperands(
1932 {sourceA, sourceB, adaptor.getDestC(), packedScaleA, packedScaleB});
1933 loweredOp.addAttributes(attrs);
1934
1935 Operation *lowered = rewriter.create(loweredOp);
1936 rewriter.replaceOp(op, lowered->getResults());
1937
1938 return success();
1939 }
1940};
1941
1942struct TransposeLoadOpLowering
1943 : public ConvertOpToLLVMPattern<TransposeLoadOp> {
1944 TransposeLoadOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1945 : ConvertOpToLLVMPattern<TransposeLoadOp>(converter), chipset(chipset) {}
1946
1947 Chipset chipset;
1948
1949 LogicalResult
1950 matchAndRewrite(TransposeLoadOp op, TransposeLoadOpAdaptor adaptor,
1951 ConversionPatternRewriter &rewriter) const override {
1952 if (chipset != kGfx950)
1953 return op.emitOpError("Non-gfx950 chipset not supported");
1954
1955 Location loc = op.getLoc();
1956 auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
1957
1958 // Elements in subbyte memrefs are stored non-contiguously,
1959 // reject if source is sub-byte memref. Use emulated memrefs instead.
1960 size_t srcElementSize =
1961 srcMemRefType.getElementType().getIntOrFloatBitWidth();
1962 if (srcElementSize < 8)
1963 return op.emitOpError("Expect source memref to have at least 8 bits "
1964 "element size, got ")
1965 << srcElementSize;
1966
1967 auto resultType = cast<VectorType>(op.getResult().getType());
1968 Value srcPtr =
1969 getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
1970 (adaptor.getSrcIndices()));
1971
1972 size_t numElements = resultType.getNumElements();
1973 size_t elementTypeSize =
1974 resultType.getElementType().getIntOrFloatBitWidth();
1975
1976 // ROCDL transpose load intrinsics return vectors of 32-bit integers, if
1977 // the element size is smaller than 16 bits.
1978 Type rocdlResultType = VectorType::get((numElements * elementTypeSize) / 32,
1979 rewriter.getIntegerType(32));
1980 Type llvmResultType = typeConverter->convertType(resultType);
1981
1982 switch (elementTypeSize) {
1983 case 4: {
1984 assert(numElements == 16);
1985 auto rocdlOp = ROCDL::ds_read_tr4_b64::create(rewriter, loc,
1986 rocdlResultType, srcPtr);
1987 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
1988 break;
1989 }
1990 case 6: {
1991 assert(numElements == 16);
1992 auto rocdlOp = ROCDL::ds_read_tr6_b96::create(rewriter, loc,
1993 rocdlResultType, srcPtr);
1994 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
1995 break;
1996 }
1997 case 8: {
1998 assert(numElements == 8);
1999 auto rocdlOp = ROCDL::ds_read_tr8_b64::create(rewriter, loc,
2000 rocdlResultType, srcPtr);
2001 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
2002 break;
2003 }
2004 case 16: {
2005 assert(numElements == 4);
2006 rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr16_b64>(op, llvmResultType,
2007 srcPtr);
2008 break;
2009 }
2010 default:
2011 return op.emitOpError("Unsupported element size for transpose load");
2012 }
2013 return success();
2014 }
2015};
2016
2017struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
2018 GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
2019 : ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}
2020
2021 Chipset chipset;
2022
2023 LogicalResult
2024 matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor,
2025 ConversionPatternRewriter &rewriter) const override {
2026 if (chipset.majorVersion < 9 || chipset.majorVersion > 10)
2027 return op.emitOpError("pre-gfx9 and post-gfx10 not supported");
2028
2029 Location loc = op.getLoc();
2030
2031 auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
2032 auto dstMemRefType = cast<MemRefType>(op.getDst().getType());
2033
2034 // TODO: instead of only transfering one element per thread, we could
2035 // augment it to transfer multiple elements per thread by issuing multiple
2036 // `global_load_lds` instructions.
2037 Type transferType = op.getTransferType();
2038 int loadWidth = [&]() -> int {
2039 if (auto transferVectorType = dyn_cast<VectorType>(transferType)) {
2040 return (transferVectorType.getNumElements() *
2041 transferVectorType.getElementTypeBitWidth()) /
2042 8;
2043 }
2044 return transferType.getIntOrFloatBitWidth() / 8;
2045 }();
2046
2047 // Currently only 1, 2, 4, 12 and 16 byte loads are supported.
2048 if (!llvm::is_contained({1, 2, 4, 12, 16}, loadWidth))
2049 return op.emitOpError("chipset unsupported element size");
2050
2051 if (chipset != kGfx950 && llvm::is_contained({12, 16}, loadWidth))
2052 return op.emitOpError("Gather to LDS instructions with 12-byte and "
2053 "16-byte load widths are only supported on gfx950");
2054
2055 Value srcPtr =
2056 getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
2057 (adaptor.getSrcIndices()));
2058 Value dstPtr =
2059 getStridedElementPtr(rewriter, loc, dstMemRefType, adaptor.getDst(),
2060 (adaptor.getDstIndices()));
2061
2062 if (op.getAsync()) {
2063 rewriter.replaceOpWithNewOp<ROCDL::LoadAsyncToLDSOp>(
2064 op, srcPtr, dstPtr, rewriter.getI32IntegerAttr(loadWidth),
2065 /*offset=*/rewriter.getI32IntegerAttr(0),
2066 /*aux=*/rewriter.getI32IntegerAttr(0), ArrayAttr{}, ArrayAttr{},
2067 ArrayAttr{});
2068 } else {
2069 rewriter.replaceOpWithNewOp<ROCDL::LoadToLDSOp>(
2070 op, srcPtr, dstPtr, rewriter.getI32IntegerAttr(loadWidth),
2071 /*offset=*/rewriter.getI32IntegerAttr(0),
2072 /*aux=*/rewriter.getI32IntegerAttr(0), ArrayAttr{}, ArrayAttr{},
2073 ArrayAttr{});
2074 }
2075
2076 return success();
2077 }
2078};
2079
2080struct GlobalLoadAsyncToLDSOpLowering
2081 : public ConvertOpToLLVMPattern<GlobalLoadAsyncToLDSOp> {
2082 GlobalLoadAsyncToLDSOpLowering(const LLVMTypeConverter &converter,
2083 Chipset chipset)
2084 : ConvertOpToLLVMPattern<GlobalLoadAsyncToLDSOp>(converter),
2085 chipset(chipset) {}
2086
2087 Chipset chipset;
2088
2089 LogicalResult
2090 matchAndRewrite(GlobalLoadAsyncToLDSOp op,
2091 GlobalLoadAsyncToLDSOpAdaptor adaptor,
2092 ConversionPatternRewriter &rewriter) const override {
2093 if (chipset < kGfx1250)
2094 return op.emitOpError(
2095 "global_load_async_to_lds is only supported on gfx1250+");
2096
2097 Location loc = op.getLoc();
2098 auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
2099 auto dstMemRefType = cast<MemRefType>(op.getDst().getType());
2100
2101 Type transferType = op.getTransferType();
2102 int transferBits =
2103 isa<VectorType>(transferType)
2104 ? cast<VectorType>(transferType).getNumElements() *
2105 cast<VectorType>(transferType).getElementTypeBitWidth()
2106 : transferType.getIntOrFloatBitWidth();
2107
2108 Value srcPtr =
2109 getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
2110 adaptor.getSrcIndices());
2111 Value dstPtr =
2112 getStridedElementPtr(rewriter, loc, dstMemRefType, adaptor.getDst(),
2113 adaptor.getDstIndices());
2114
2115 if (op.getMask()) {
2116 Value mask = adaptor.getMask();
2117 int64_t nullptrVal =
2118 llvm::AMDGPU::getNullPointerValue(llvm::AMDGPUAS::LOCAL_ADDRESS);
2119 Value nullInt =
2120 createI32Constant(rewriter, loc, static_cast<int32_t>(nullptrVal));
2121 Value nullPtr =
2122 LLVM::IntToPtrOp::create(rewriter, loc, dstPtr.getType(), nullInt);
2123 dstPtr = LLVM::SelectOp::create(rewriter, loc, mask, dstPtr, nullPtr);
2124 }
2125
2126 auto offset = rewriter.getI32IntegerAttr(0);
2127 auto aux = rewriter.getI32IntegerAttr(0);
2128
2129 switch (transferBits) {
2130 case 8:
2131 rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadAsyncToLDSB8Op>(
2132 op, srcPtr, dstPtr, offset, aux, ArrayAttr{}, ArrayAttr{},
2133 ArrayAttr{});
2134 break;
2135 case 32:
2136 rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadAsyncToLDSB32Op>(
2137 op, srcPtr, dstPtr, offset, aux, ArrayAttr{}, ArrayAttr{},
2138 ArrayAttr{});
2139 break;
2140 case 64:
2141 rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadAsyncToLDSB64Op>(
2142 op, srcPtr, dstPtr, offset, aux, ArrayAttr{}, ArrayAttr{},
2143 ArrayAttr{});
2144 break;
2145 case 128:
2146 rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadAsyncToLDSB128Op>(
2147 op, srcPtr, dstPtr, offset, aux, ArrayAttr{}, ArrayAttr{},
2148 ArrayAttr{});
2149 break;
2150 default:
2151 return op.emitOpError("unsupported transfer width");
2152 }
2153 return success();
2154 }
2155};
2156
namespace {
// Lowering patterns for the packed fp8 / scaled small-float conversion ops.
// Each pattern only declares matchAndRewrite here; the definitions follow
// below.

/// Lowers amdgpu.ext_packed_fp8 to the ROCDL cvt.f32.{fp8,bf8} /
/// cvt.pk.f32.{fp8,bf8} operations.
struct ExtPackedFp8OpLowering final
    : public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
  ExtPackedFp8OpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::ExtPackedFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Lowers amdgpu.scaled_ext_packed_matrix to the ROCDL cvt.scale.pk*
/// operations (gfx1250).
struct ScaledExtPackedMatrixOpLowering final
    : public ConvertOpToLLVMPattern<ScaledExtPackedMatrixOp> {
  ScaledExtPackedMatrixOpLowering(const LLVMTypeConverter &converter,
                                  Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedMatrixOp>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(ScaledExtPackedMatrixOp op,
                  ScaledExtPackedMatrixOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Lowers amdgpu.packed_trunc_2xfp8 to the ROCDL packed f32->fp8 truncation
/// operations.
struct PackedTrunc2xFp8OpLowering final
    : public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> {
  PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter,
                             Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::PackedTrunc2xFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Lowers amdgpu.packed_stoch_round_fp8 to the ROCDL stochastic-rounding
/// fp8 conversion operations.
struct PackedStochRoundFp8OpLowering final
    : public ConvertOpToLLVMPattern<PackedStochRoundFp8Op> {
  PackedStochRoundFp8OpLowering(const LLVMTypeConverter &converter,
                                Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::PackedStochRoundFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PackedStochRoundFp8Op op,
                  PackedStochRoundFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Lowers amdgpu.scaled_ext_packed to the ROCDL scaled conversion
/// operations.
struct ScaledExtPackedOpLowering final
    : public ConvertOpToLLVMPattern<ScaledExtPackedOp> {
  ScaledExtPackedOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedOp>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Lowers amdgpu.packed_scaled_trunc to the ROCDL scaled truncation
/// operations.
struct PackedScaledTruncOpLowering final
    : public ConvertOpToLLVMPattern<PackedScaledTruncOp> {
  PackedScaledTruncOpLowering(const LLVMTypeConverter &converter,
                              Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::PackedScaledTruncOp>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

} // end namespace
2237
// Lowers amdgpu.ext_packed_fp8: extracts one fp8/bf8 value (or a pair, for a
// vector result) from a packed 32-bit word and extends it to f32.
LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
    ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  // Hardware fp8 conversions exist on gfx942 (NANOO formats) and on chips
  // with OCP fp8 support; everything else bails out.
  if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type v4i8 =
      getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type()));
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
  Type f32 = getTypeConverter()->convertType(op.getResult().getType());

  Value source = adaptor.getSource();
  auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
  auto resultVecType = dyn_cast<VectorType>(op.getResult().getType());
  Type sourceElemType = getElementTypeOrSelf(op.getSource());
  // Extend to a v4i8: the ROCDL conversion ops consume a full 32-bit word,
  // so shorter sources (a scalar or a vector of fewer than 4 elements) are
  // widened by inserting their elements into an undef 4xi8 vector.
  if (!sourceVecType || sourceVecType.getNumElements() < 4) {
    Value longVec = LLVM::UndefOp::create(rewriter, loc, v4i8);
    if (!sourceVecType) {
      // Scalar source: place it in lane 0.
      longVec = LLVM::InsertElementOp::create(
          rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0));
    } else {
      // Short vector source: copy each element into the widened vector.
      for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
        Value idx = createI32Constant(rewriter, loc, i);
        Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
        longVec =
            LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
      }
    }
    source = longVec;
  }
  // Reinterpret the packed bytes as the i32 word the intrinsics expect.
  Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
  // Vector results use the packed (pair-producing) conversions; scalar
  // results the single-element ones. `op.getIndex()` selects which packed
  // element(s) to extend.
  // NOTE(review): if the source element type is neither the expected bf8 nor
  // fp8 for this chipset, no replacement is created but success is still
  // returned — presumably the op verifier rules that out; confirm.
  if (resultVecType) {
    if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
      rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source,
                                                        op.getIndex());
    } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
      rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Fp8Op>(op, f32, i32Source,
                                                        op.getIndex());
    }
  } else {
    if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
      rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
                                                      op.getIndex());
    } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
      rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
                                                      op.getIndex());
    }
  }
  return success();
}
2291
/// Packs the high-level scale-selection parameters of
/// amdgpu.scaled_ext_packed_matrix into the single `scaleSel` immediate used
/// by the rocdl.cvt.scale.pk*.f*.f* operations.
///
/// \param blockSize      scaling block size, 16 or 32.
/// \param bitWidth       bit width of the source element type (4, 6, or 8).
/// \param scaleWaveHalf  which half of the wave supplies the scales (0 or 1).
/// \param firstScaleByte byte offset of the first scale within the packed
///                       scale operand.
int32_t getScaleSel(int32_t blockSize, unsigned bitWidth, int32_t scaleWaveHalf,
                    int32_t firstScaleByte) {
  // When lowering amdgpu.scaled_ext_packed_matrix to rocdl.cvt.scale.pk*.f*.f*
  // operations, the attributes blockSize, sourceType, scaleWaveHalf, and
  // firstScaleByte are merged into a single attribute scaleSel. This is how
  // those values are merged together. (Note: scaleWaveHalf isn't a high-level
  // attribute but is derived from firstScaleLane).
  assert(llvm::is_contained({16, 32}, blockSize));
  assert(llvm::is_contained({4u, 6u, 8u}, bitWidth));

  const bool isFp8 = bitWidth == 8;
  const bool isBlock16 = blockSize == 16;

  if (!isFp8) {
    // fp4/fp6 layout: bit0 = block-16 flag, bit1 = firstScaleByte==2,
    // bit2 = wave half.
    int32_t bit0 = isBlock16;
    assert(llvm::is_contained({0, 1, 2}, firstScaleByte));
    int32_t bit1 = (firstScaleByte == 2) << 1;
    assert(llvm::is_contained({0, 1}, scaleWaveHalf));
    int32_t bit2 = scaleWaveHalf << 2;
    return bit2 | bit1 | bit0;
  }

  // fp8 layout: bit0 = block-16 flag, bits2..1 = firstScaleByte,
  // bit3 = wave half.
  int32_t bit0 = isBlock16;
  // firstScaleByte is guaranteed to be defined by two bits.
  assert(llvm::is_contained({0, 1, 2, 3}, firstScaleByte));
  int32_t bits2and1 = firstScaleByte << 1;
  assert(llvm::is_contained({0, 1}, scaleWaveHalf));
  int32_t bit3 = scaleWaveHalf << 3;
  int32_t bits = bit3 | bits2and1 | bit0;
  // These are invalid cases.
  assert(!llvm::is_contained(
      {0b0011, 0b0101, 0b0111, 0b1000, 0b1001, 0b1011, 0b1111}, bits));
  return bits;
}
2326
2327static std::optional<StringRef>
2328scaledExtPacked816ToIntrinsic(Type srcElemType, Type destElemType) {
2329 using fp4 = Float4E2M1FNType;
2330 using fp8 = Float8E4M3FNType;
2331 using bf8 = Float8E5M2Type;
2332 using fp6 = Float6E2M3FNType;
2333 using bf6 = Float6E3M2FNType;
2334 if (isa<fp4>(srcElemType)) {
2335 if (destElemType.isF16())
2336 return ROCDL::CvtPkScalePk8F16Fp4Op::getOperationName();
2337 if (destElemType.isBF16())
2338 return ROCDL::CvtPkScalePk8Bf16Fp4Op::getOperationName();
2339 if (destElemType.isF32())
2340 return ROCDL::CvtPkScalePk8F32Fp4Op::getOperationName();
2341 return std::nullopt;
2342 }
2343 if (isa<fp8>(srcElemType)) {
2344 if (destElemType.isF16())
2345 return ROCDL::CvtPkScalePk8F16Fp8Op::getOperationName();
2346 if (destElemType.isBF16())
2347 return ROCDL::CvtPkScalePk8Bf16Fp8Op::getOperationName();
2348 if (destElemType.isF32())
2349 return ROCDL::CvtPkScalePk8F32Fp8Op::getOperationName();
2350 return std::nullopt;
2351 }
2352 if (isa<bf8>(srcElemType)) {
2353 if (destElemType.isF16())
2354 return ROCDL::CvtPkScalePk8F16Bf8Op::getOperationName();
2355 if (destElemType.isBF16())
2356 return ROCDL::CvtPkScalePk8Bf16Bf8Op::getOperationName();
2357 if (destElemType.isF32())
2358 return ROCDL::CvtPkScalePk8F32Bf8Op::getOperationName();
2359 return std::nullopt;
2360 }
2361 if (isa<fp6>(srcElemType)) {
2362 if (destElemType.isF16())
2363 return ROCDL::CvtPkScalePk16F16Fp6Op::getOperationName();
2364 if (destElemType.isBF16())
2365 return ROCDL::CvtPkScalePk16Bf16Fp6Op::getOperationName();
2366 if (destElemType.isF32())
2367 return ROCDL::CvtPkScalePk16F32Fp6Op::getOperationName();
2368 return std::nullopt;
2369 }
2370 if (isa<bf6>(srcElemType)) {
2371 if (destElemType.isF16())
2372 return ROCDL::CvtPkScalePk16F16Bf6Op::getOperationName();
2373 if (destElemType.isBF16())
2374 return ROCDL::CvtPkScalePk16Bf16Bf6Op::getOperationName();
2375 if (destElemType.isF32())
2376 return ROCDL::CvtPkScalePk16F32Bf6Op::getOperationName();
2377 return std::nullopt;
2378 }
2379 llvm_unreachable("invalid combination of element types for packed conversion "
2380 "instructions");
2381}
2382
/// Lower amdgpu.scaled_ext_packed_matrix to the ROCDL cvt.scale.pk* family of
/// intrinsics. Only gfx1250 provides these instructions; other chipsets fail
/// to match.
LogicalResult ScaledExtPackedMatrixOpLowering::matchAndRewrite(
    ScaledExtPackedMatrixOp op, ScaledExtPackedMatrixOpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  using fp4 = Float4E2M1FNType;
  using fp8 = Float8E4M3FNType;
  using bf8 = Float8E5M2Type;
  using fp6 = Float6E2M3FNType;
  using bf6 = Float6E3M2FNType;
  Location loc = op.getLoc();
  if (chipset != kGfx1250) {
    return rewriter.notifyMatchFailure(
        loc,
        "Scaled fp packed conversion instructions are not available on target "
        "architecture and their emulation is not implemented");
  }
  // Convert user-facing firstScaleLane (0 or 16) to the half of the wave that
  // is being selected.
  int32_t scaleWaveHalf = op.getFirstScaleLane() / 16;
  int32_t firstScaleByte = op.getFirstScaleByte();
  int32_t blockSize = op.getBlockSize();
  auto sourceType = cast<VectorType>(op.getSource().getType());
  auto srcElemType = cast<FloatType>(sourceType.getElementType());
  unsigned bitWidth = srcElemType.getWidth();

  auto targetType = cast<VectorType>(op.getResult().getType());
  auto destElemType = cast<FloatType>(targetType.getElementType());

  IntegerType i32 = rewriter.getI32Type();
  Value source = adaptor.getSource();
  Type llvmResultType = typeConverter->convertType(op.getResult().getType());
  // The source data is handed to the intrinsic packed into 32-bit registers:
  // one i32 for fp4, <2 x i32> for fp8/bf8, <3 x i32> for fp6/bf6 (matching
  // the Pk8/Pk16 element counts of the selected intrinsics).
  Type packedType = nullptr;
  if (isa<fp4>(srcElemType)) {
    packedType = i32;
    packedType = getTypeConverter()->convertType(packedType);
  } else if (isa<fp8, bf8>(srcElemType)) {
    packedType = VectorType::get(2, i32);
    packedType = getTypeConverter()->convertType(packedType);
  } else if (isa<fp6, bf6>(srcElemType)) {
    packedType = VectorType::get(3, i32);
    packedType = getTypeConverter()->convertType(packedType);
  } else {
    llvm_unreachable("invalid element type for packed scaled ext");
  }

  // convertType returns null on failure; bail out cleanly rather than crash.
  if (!packedType || !llvmResultType) {
    return rewriter.notifyMatchFailure(op, "type conversion failed");
  }

  std::optional<StringRef> maybeIntrinsic =
      scaledExtPacked816ToIntrinsic(srcElemType, destElemType);
  if (!maybeIntrinsic.has_value())
    return op.emitOpError(
        "no intrinsic matching packed scaled conversion on the given chipset");

  // Encode block size, element width, wave half, and scale byte into the
  // intrinsic's immediate scaleSel operand.
  int32_t scaleSel =
      getScaleSel(blockSize, bitWidth, scaleWaveHalf, firstScaleByte);
  Value castedScale =
      LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getScale());
  Value castedSource =
      LLVM::BitcastOp::create(rewriter, loc, packedType, source);

  // The intrinsic was chosen at runtime by name, so build the op generically
  // through OperationState instead of a typed builder.
  OperationState loweredOp(loc, *maybeIntrinsic);
  loweredOp.addTypes({llvmResultType});
  loweredOp.addOperands({castedSource, castedScale});

  SmallVector<NamedAttribute, 1> attrs;
  attrs.push_back(
      NamedAttribute("scaleSel", rewriter.getI32IntegerAttr(scaleSel)));

  loweredOp.addAttributes(attrs);
  Operation *lowered = rewriter.create(loweredOp);
  rewriter.replaceOp(op, lowered);

  return success();
}
2458
2459LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
2460 ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
2461 ConversionPatternRewriter &rewriter) const {
2462 Location loc = op.getLoc();
2463 if (chipset != kGfx950)
2464 return rewriter.notifyMatchFailure(
2465 loc, "Scaled fp conversion instructions are not available on target "
2466 "architecture and their emulation is not implemented");
2467 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2468
2469 Value source = adaptor.getSource();
2470 Value scale = adaptor.getScale();
2471
2472 VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
2473 Type sourceElemType = sourceVecType.getElementType();
2474 VectorType destVecType = cast<VectorType>(op.getResult().getType());
2475 Type destElemType = destVecType.getElementType();
2476
2477 VectorType packedVecType;
2478 if (isa<Float8E5M2Type, Float8E4M3FNType>(sourceElemType)) {
2479 VectorType v4i8 = VectorType::get(4, rewriter.getI8Type());
2480 packedVecType = cast<VectorType>(getTypeConverter()->convertType(v4i8));
2481 } else if (isa<Float4E2M1FNType>(sourceElemType)) {
2482 VectorType v8i4 = VectorType::get(8, rewriter.getI4Type());
2483 packedVecType = cast<VectorType>(getTypeConverter()->convertType(v8i4));
2484 } else {
2485 llvm_unreachable("invalid element type for scaled ext");
2486 }
2487
2488 // Extend to a packedVectorType
2489 if (sourceVecType.getNumElements() < packedVecType.getNumElements()) {
2490 Value longVec = LLVM::ZeroOp::create(rewriter, loc, packedVecType);
2491 if (!sourceVecType) {
2492 longVec = LLVM::InsertElementOp::create(
2493 rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0));
2494 } else {
2495 for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
2496 Value idx = createI32Constant(rewriter, loc, i);
2497 Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
2498 longVec =
2499 LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
2500 }
2501 }
2502 source = longVec;
2503 }
2504 Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
2505
2506 if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF32())
2507 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Bf8Op>(
2508 op, destVecType, i32Source, scale, op.getIndex());
2509 else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF16())
2510 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Bf8Op>(
2511 op, destVecType, i32Source, scale, op.getIndex());
2512 else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isBF16())
2513 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Bf8Op>(
2514 op, destVecType, i32Source, scale, op.getIndex());
2515 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF32())
2516 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp8Op>(
2517 op, destVecType, i32Source, scale, op.getIndex());
2518 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF16())
2519 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp8Op>(
2520 op, destVecType, i32Source, scale, op.getIndex());
2521 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isBF16())
2522 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp8Op>(
2523 op, destVecType, i32Source, scale, op.getIndex());
2524 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF32())
2525 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp4Op>(
2526 op, destVecType, i32Source, scale, op.getIndex());
2527 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF16())
2528 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp4Op>(
2529 op, destVecType, i32Source, scale, op.getIndex());
2530 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isBF16())
2531 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp4Op>(
2532 op, destVecType, i32Source, scale, op.getIndex());
2533 else
2534 return failure();
2535
2536 return success();
2537}
2538
/// Lower amdgpu.packed_scaled_trunc to ROCDL cvt.scalef32.pk* truncation
/// intrinsics (gfx950 only). The narrow results are packed into a 32-bit
/// register, optionally merging with previously written lanes.
LogicalResult PackedScaledTruncOpLowering::matchAndRewrite(
    PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (chipset != kGfx950)
    return rewriter.notifyMatchFailure(
        loc, "Scaled fp conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  // The packed destination register is modeled as i32 for fp4 results and
  // <2 x i16> for the fp8/bf8 results (see intResultType below).
  Type v2i16 = getTypeConverter()->convertType(
      VectorType::get(2, rewriter.getI16Type()));
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Type resultType = op.getResult().getType();
  Type resultElemType = getElementTypeOrSelf(resultType);
  VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
  Type sourceElemType = sourceVecType.getElementType();

  Type intResultType = isa<Float4E2M1FNType>(resultElemType) ? i32 : v2i16;

  Value source = adaptor.getSource();
  Value scale = adaptor.getScale();
  // `existing` carries lanes of the packed register to preserve; when absent,
  // start from an all-zero register.
  Value existing = adaptor.getExisting();
  if (existing)
    existing = LLVM::BitcastOp::create(rewriter, loc, intResultType, existing);
  else
    existing = LLVM::ZeroOp::create(rewriter, loc, intResultType);

  // Widen a single-element source into a two-element vector, zero-filling the
  // second lane, since the intrinsics operate on pairs.
  if (sourceVecType.getNumElements() < 2) {
    Value c0 = createI32Constant(rewriter, loc, 0);
    Value elem0 = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
    VectorType v2 = VectorType::get(2, sourceElemType);
    source = LLVM::ZeroOp::create(rewriter, loc, v2);
    source = LLVM::InsertElementOp::create(rewriter, loc, source, elem0, c0);
  }

  // The f32-source intrinsics take the two lanes as separate scalar operands
  // rather than one vector; sourceA/sourceB stay null for f16/bf16 sources.
  Value sourceA, sourceB;
  if (sourceElemType.isF32()) {
    Value c0 = createI32Constant(rewriter, loc, 0);
    Value c1 = createI32Constant(rewriter, loc, 1);
    sourceA = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
    sourceB = LLVM::ExtractElementOp::create(rewriter, loc, source, c1);
  }

  // Dispatch on (source element type, result element type).
  Value result;
  if (sourceElemType.isF32() && isa<Float8E5M2Type>(resultElemType))
    result = ROCDL::CvtScaleF32PkBf8F32Op::create(rewriter, loc, intResultType,
                                                  existing, sourceA, sourceB,
                                                  scale, op.getIndex());
  else if (sourceElemType.isF16() && isa<Float8E5M2Type>(resultElemType))
    result = ROCDL::CvtScaleF32PkBf8F16Op::create(
        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  else if (sourceElemType.isBF16() && isa<Float8E5M2Type>(resultElemType))
    result = ROCDL::CvtScaleF32PkBf8Bf16Op::create(
        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  else if (sourceElemType.isF32() && isa<Float8E4M3FNType>(resultElemType))
    result = ROCDL::CvtScaleF32PkFp8F32Op::create(rewriter, loc, intResultType,
                                                  existing, sourceA, sourceB,
                                                  scale, op.getIndex());
  else if (sourceElemType.isF16() && isa<Float8E4M3FNType>(resultElemType))
    result = ROCDL::CvtScaleF32PkFp8F16Op::create(
        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  else if (sourceElemType.isBF16() && isa<Float8E4M3FNType>(resultElemType))
    result = ROCDL::CvtScaleF32PkFp8Bf16Op::create(
        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  else if (sourceElemType.isF32() && isa<Float4E2M1FNType>(resultElemType))
    result = ROCDL::CvtScaleF32PkFp4F32Op::create(rewriter, loc, intResultType,
                                                  existing, sourceA, sourceB,
                                                  scale, op.getIndex());
  else if (sourceElemType.isF16() && isa<Float4E2M1FNType>(resultElemType))
    result = ROCDL::CvtScaleF32PkFp4F16Op::create(
        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  else if (sourceElemType.isBF16() && isa<Float4E2M1FNType>(resultElemType))
    result = ROCDL::CvtScaleF32PkFp4Bf16Op::create(
        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  else
    return failure();

  // Cast the packed integer register back to the op's declared result type.
  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
  return success();
}
2620
/// Lower amdgpu.packed_trunc_2xfp8 to ROCDL cvt.pk.{bf8,fp8}.f32, which packs
/// two f32 values into half of a 32-bit register (gfx942 or OCP-fp8 chips).
LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
    PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Type resultType = op.getResult().getType();
  Type resultElemType = getElementTypeOrSelf(resultType);

  // The second source lane is optional; an undef placeholder is fine since
  // that lane of the result is unspecified in that case.
  Value sourceA = adaptor.getSourceA();
  Value sourceB = adaptor.getSourceB();
  if (!sourceB)
    sourceB = LLVM::UndefOp::create(rewriter, loc, sourceA.getType());
  // `existing` carries the other lanes of the 32-bit result word to preserve;
  // undef when the caller does not care about them.
  Value existing = adaptor.getExisting();
  if (existing)
    existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
  else
    existing = LLVM::UndefOp::create(rewriter, loc, i32);

  Value result;
  if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
    result = ROCDL::CvtPkBf8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
                                          existing, op.getWordIndex());
  else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
    result = ROCDL::CvtPkFp8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
                                          existing, op.getWordIndex());

  // NOTE(review): if neither predicate above matched, `result` is a null
  // Value here. This appears to rely on op verification restricting the
  // result element type to bf8/fp8 for the chipsets accepted by the guard
  // above — confirm that invariant.
  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
  return success();
}
2656
2657LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
2658 PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
2659 ConversionPatternRewriter &rewriter) const {
2660 Location loc = op.getLoc();
2661 if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
2662 return rewriter.notifyMatchFailure(
2663 loc, "Fp8 conversion instructions are not available on target "
2664 "architecture and their emulation is not implemented");
2665 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2666
2667 Type resultType = op.getResult().getType();
2668 Type resultElemType = getElementTypeOrSelf(resultType);
2669
2670 Value source = adaptor.getSource();
2671 Value stoch = adaptor.getStochiasticParam();
2672 Value existing = adaptor.getExisting();
2673 if (existing)
2674 existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
2675 else
2676 existing = LLVM::UndefOp::create(rewriter, loc, i32);
2677
2678 Value result;
2679 if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
2680 result = ROCDL::CvtSrBf8F32Op::create(rewriter, loc, i32, source, stoch,
2681 existing, op.getStoreIndex());
2682 else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
2683 result = ROCDL::CvtSrFp8F32Op::create(rewriter, loc, i32, source, stoch,
2684 existing, op.getStoreIndex());
2685
2686 result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
2687 op, getTypeConverter()->convertType(resultType), result);
2688 return success();
2689}
2690
// Convert the amdgpu.dpp operation into the corresponding ROCDL DPPUpdateOp,
// packing sub-32-bit operands into 32-bit registers and encoding the permute
// kind into the hardware dpp_ctrl immediate.
struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
  AMDGPUDPPLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    // Pick the 32/64-bit LLVM type the DPP intrinsic will operate on. Types
    // narrower than 32 bits are widened to i32 below.
    // NOTE(review): a type wider than 64 bits would leave llvmType null —
    // presumably ruled out by op verification; confirm.
    Location loc = DppOp.getLoc();
    Value src = adaptor.getSrc();
    Value old = adaptor.getOld();
    Type srcType = src.getType();
    Type oldType = old.getType();
    Type llvmType = nullptr;
    if (srcType.getIntOrFloatBitWidth() < 32) {
      llvmType = rewriter.getI32Type();
    } else if (isa<FloatType>(srcType)) {
      llvmType = (srcType.getIntOrFloatBitWidth() == 32)
                     ? rewriter.getF32Type()
                     : rewriter.getF64Type();
    } else if (isa<IntegerType>(srcType)) {
      llvmType = (srcType.getIntOrFloatBitWidth() == 32)
                     ? rewriter.getI32Type()
                     : rewriter.getI64Type();
    }
    // Integer type of the same width as the source, used as a bitcast bridge
    // for narrow float operands.
    auto llvmSrcIntType = typeConverter->convertType(
        rewriter.getIntegerType(srcType.getIntOrFloatBitWidth()));

    // If the operand is 16 bits or narrower, widen it to i32 by inserting it
    // into lane 0 of an appropriately sized vector and bitcasting the vector.
    auto convertOperand = [&](Value operand, Type operandType) {
      if (operandType.getIntOrFloatBitWidth() <= 16) {
        if (llvm::isa<FloatType>(operandType)) {
          operand =
              LLVM::BitcastOp::create(rewriter, loc, llvmSrcIntType, operand);
        }
        auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
            32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
        Value undefVec = LLVM::UndefOp::create(rewriter, loc, llvmVecType);
        operand =
            LLVM::InsertElementOp::create(rewriter, loc, undefVec, operand,
                                          createI32Constant(rewriter, loc, 0));
        operand = LLVM::BitcastOp::create(rewriter, loc, llvmType, operand);
      }
      return operand;
    };

    src = convertOperand(src, srcType);
    old = convertOperand(old, oldType);

    // This is taken from the following file llvm/lib/Target/AMDGPU/SIDefines.h
    enum DppCtrl : unsigned {
      ROW_SHL0 = 0x100,
      ROW_SHR0 = 0x110,
      ROW_ROR0 = 0x120,
      WAVE_SHL1 = 0x130,
      WAVE_ROL1 = 0x134,
      WAVE_SHR1 = 0x138,
      WAVE_ROR1 = 0x13C,
      ROW_MIRROR = 0x140,
      ROW_HALF_MIRROR = 0x141,
      BCAST15 = 0x142,
      BCAST31 = 0x143,
    };

    // Encode the permute kind (plus its optional argument) into the hardware
    // dpp_ctrl immediate. Note the local variable shadows the enum's name.
    auto kind = DppOp.getKind();
    auto permArgument = DppOp.getPermArgument();
    uint32_t DppCtrl = 0;

    switch (kind) {

    case DPPPerm::quad_perm: {
      // Four 2-bit lane selectors packed into the low byte.
      auto quadPermAttr = cast<ArrayAttr>(*permArgument);
      int32_t i = 0;
      for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
        uint32_t num = elem.getInt();
        DppCtrl |= num << (i * 2);
        i++;
      }
      break;
    }
    case DPPPerm::row_shl: {
      auto intAttr = cast<IntegerAttr>(*permArgument);
      DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
      break;
    }
    case DPPPerm::row_shr: {
      auto intAttr = cast<IntegerAttr>(*permArgument);
      DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
      break;
    }
    case DPPPerm::row_ror: {
      auto intAttr = cast<IntegerAttr>(*permArgument);
      DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
      break;
    }
    case DPPPerm::wave_shl:
      DppCtrl = DppCtrl::WAVE_SHL1;
      break;
    case DPPPerm::wave_shr:
      DppCtrl = DppCtrl::WAVE_SHR1;
      break;
    case DPPPerm::wave_rol:
      DppCtrl = DppCtrl::WAVE_ROL1;
      break;
    case DPPPerm::wave_ror:
      DppCtrl = DppCtrl::WAVE_ROR1;
      break;
    case DPPPerm::row_mirror:
      DppCtrl = DppCtrl::ROW_MIRROR;
      break;
    case DPPPerm::row_half_mirror:
      DppCtrl = DppCtrl::ROW_HALF_MIRROR;
      break;
    case DPPPerm::row_bcast_15:
      DppCtrl = DppCtrl::BCAST15;
      break;
    case DPPPerm::row_bcast_31:
      DppCtrl = DppCtrl::BCAST31;
      break;
    }

    // Check for row_mask, bank_mask, bound_ctrl if they exist and create
    // constants
    auto rowMask = DppOp->getAttrOfType<IntegerAttr>("row_mask").getInt();
    auto bankMask = DppOp->getAttrOfType<IntegerAttr>("bank_mask").getInt();
    bool boundCtrl = DppOp->getAttrOfType<BoolAttr>("bound_ctrl").getValue();

    // create a ROCDL_DPPMovOp instruction with the appropriate attributes
    auto dppMovOp =
        ROCDL::DPPUpdateOp::create(rewriter, loc, llvmType, old, src, DppCtrl,
                                   rowMask, bankMask, boundCtrl);

    // Undo the widening: truncate back to the source width, and bitcast back
    // to the original float type when the source was not an integer.
    Value result = dppMovOp.getRes();
    if (srcType.getIntOrFloatBitWidth() < 32) {
      result = LLVM::TruncOp::create(rewriter, loc, llvmSrcIntType, result);
      if (!llvm::isa<IntegerType>(srcType)) {
        result = LLVM::BitcastOp::create(rewriter, loc, srcType, result);
      }
    }

    // We are replacing the AMDGPU_DPPOp instruction with the new
    // ROCDL_DPPMovOp instruction
    rewriter.replaceOp(DppOp, ValueRange(result));
    return success();
  }
};
2841
2842struct AMDGPUSwizzleBitModeLowering
2843 : public ConvertOpToLLVMPattern<SwizzleBitModeOp> {
2845
2846 LogicalResult
2847 matchAndRewrite(SwizzleBitModeOp op, OpAdaptor adaptor,
2848 ConversionPatternRewriter &rewriter) const override {
2849 Location loc = op.getLoc();
2850 Type i32 = rewriter.getI32Type();
2851 Value src = adaptor.getSrc();
2852 SmallVector<Value> decomposed;
2853 if (failed(LLVM::decomposeValue(rewriter, loc, src, i32, decomposed)))
2854 return rewriter.notifyMatchFailure(op,
2855 "failed to decompose value to i32");
2856 unsigned andMask = op.getAndMask();
2857 unsigned orMask = op.getOrMask();
2858 unsigned xorMask = op.getXorMask();
2859
2860 // bit 15 is 0 for the BitMode swizzle.
2861 // https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/
2862 unsigned mask = andMask | (orMask << 5) | (xorMask << 10);
2863 Value maskValue = createI32Constant(rewriter, loc, mask);
2864 SmallVector<Value> swizzled;
2865 for (Value v : decomposed) {
2866 Value res =
2867 ROCDL::DsSwizzleOp::create(rewriter, loc, v.getType(), v, maskValue);
2868 swizzled.emplace_back(res);
2869 }
2870
2871 Value result = LLVM::composeValue(rewriter, loc, swizzled, src.getType());
2872 rewriter.replaceOp(op, result);
2873 return success();
2874 }
2875};
2876
/// Lower amdgpu.permlane_swap to the ROCDL permlane16/32 swap intrinsics,
/// processing the source one i32 piece at a time (gfx950+ only).
struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneSwapOp> {

  AMDGPUPermlaneLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<PermlaneSwapOp>(converter), chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PermlaneSwapOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset < kGfx950)
      return op->emitOpError("permlane_swap is only supported on gfx950+");

    Location loc = op.getLoc();
    Type i32 = rewriter.getI32Type();
    Value src = adaptor.getSrc();
    unsigned rowLength = op.getRowLength();
    bool fi = op.getFetchInactive();
    bool boundctrl = op.getBoundCtrl();

    // Wide values are split into i32 pieces; each piece is swapped
    // independently and the results are recomposed at the end.
    SmallVector<Value> decomposed;
    if (failed(LLVM::decomposeValue(rewriter, loc, src, i32, decomposed)))
      return rewriter.notifyMatchFailure(op,
                                         "failed to decompose value to i32");

    SmallVector<Value> permuted;
    for (Value v : decomposed) {
      Value res;
      // The swap intrinsics return a literal struct of two values of the
      // piece type.
      Type i32pair = LLVM::LLVMStructType::getLiteral(
          rewriter.getContext(), {v.getType(), v.getType()});

      if (rowLength == 16)
        res = ROCDL::Permlane16SwapOp::create(rewriter, loc, i32pair, v, v, fi,
                                              boundctrl);
      else if (rowLength == 32)
        res = ROCDL::Permlane32SwapOp::create(rewriter, loc, i32pair, v, v, fi,
                                              boundctrl);
      else
        llvm_unreachable("unsupported row length");

      Value vdst0 = LLVM::ExtractValueOp::create(rewriter, loc, res, {0});
      Value vdst1 = LLVM::ExtractValueOp::create(rewriter, loc, res, {1});

      Value isEqual = LLVM::ICmpOp::create(rewriter, loc,
                                           LLVM::ICmpPredicate::eq, vdst0, v);

      // Per `permlane(16|32)` semantics: if the first extracted element equals
      // 'v', the result is the second element; otherwise it is the first.
      Value vdstNew =
          LLVM::SelectOp::create(rewriter, loc, isEqual, vdst1, vdst0);
      permuted.emplace_back(vdstNew);
    }

    Value result = LLVM::composeValue(rewriter, loc, permuted, src.getType());
    rewriter.replaceOp(op, result);
    return success();
  }
};
2935
2936//===----------------------------------------------------------------------===//
2937// In-LDS Barrier Operations
2938//===----------------------------------------------------------------------===//
2939
// Bit layout of ds_barrier_state (as i64):
// [63:32] init count (32 bits)
// [31:29] phase (3 bits)
// [28:0] pending count (29 bits)
// NOTE(review): all barrier-state lowerings below assume this layout;
// presumably it matches the gfx1250 hardware barrier-state word — confirm
// against the ISA documentation if the layout ever changes.
constexpr int32_t kDsBarrierPendingCountBitWidth = 29;
// The phase field starts immediately above the pending count.
constexpr int32_t kDsBarrierPhasePos = kDsBarrierPendingCountBitWidth;
// The init count occupies the high 32 bits of the i64 state word.
constexpr int32_t kDsBarrierInitCountPos = 32;
// Mask selecting the low 29 pending-count bits of the (truncated) state.
constexpr int32_t kDsBarrierPendingCountMask =
    (1 << kDsBarrierPendingCountBitWidth) - 1;
2949
/// Lower barrier initialization: build the 64-bit barrier state word and
/// store it to LDS with release ordering at workgroup scope (gfx1250+ only).
struct DsBarrierInitOpLowering
    : public ConvertOpToLLVMPattern<DsBarrierInitOp> {
  Chipset chipset;

  DsBarrierInitOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<DsBarrierInitOp>(converter), chipset(chipset) {}

  LogicalResult
  matchAndRewrite(DsBarrierInitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset < kGfx1250)
      return op->emitOpError("only supported on gfx1250+");

    Location loc = op.getLoc();
    Type i64 = rewriter.getI64Type();

    MemRefType memrefType = cast<MemRefType>(op.getBase().getType());
    Value ptr = getStridedElementPtr(rewriter, loc, memrefType,
                                     adaptor.getBase(), adaptor.getIndices());

    // Note: We give participants as the number of arrivals that have to occur
    // before the phase changes. Hardware changes the phase when updating the
    // pending count would underflow, so we subtract 1 to get the behavior we're
    // looking for.
    Value initCount =
        LLVM::SubOp::create(rewriter, loc, adaptor.getParticipants(),
                            createI32Constant(rewriter, loc, 1));

    // Just a bit of paranoia, but this also allows for configurable width if
    // that becomes a thing.
    Value countMask =
        createI32Constant(rewriter, loc, kDsBarrierPendingCountMask);
    Value maskedCount32 =
        LLVM::AndOp::create(rewriter, loc, initCount, countMask);
    Value maskedCount = LLVM::ZExtOp::create(rewriter, loc, i64, maskedCount32);

    // The same count is written twice: to the init-count field (bits [63:32])
    // and as the initial pending count (low bits); the phase starts at 0.
    Value initCountShifted = LLVM::ShlOp::create(
        rewriter, loc, maskedCount,
        createI64Constant(rewriter, loc, kDsBarrierInitCountPos));
    Value barrierState =
        LLVM::OrOp::create(rewriter, loc, initCountShifted, maskedCount);

    // Release store at workgroup scope so other waves observe a fully
    // initialized state word.
    LLVM::StoreOp::create(
        rewriter, loc, barrierState, ptr, /*alignment=*/8, /*isVolatile=*/false,
        /*isNonTemporal=*/false,
        /*isInvariantGroup=*/false, LLVM::AtomicOrdering::release,
        /*syncscope=*/"workgroup");

    rewriter.eraseOp(op);
    return success();
  }
};
3002
3003struct DsBarrierPollStateOpLowering
3004 : public ConvertOpToLLVMPattern<DsBarrierPollStateOp> {
3005 Chipset chipset;
3006
3007 DsBarrierPollStateOpLowering(const LLVMTypeConverter &converter,
3008 Chipset chipset)
3009 : ConvertOpToLLVMPattern<DsBarrierPollStateOp>(converter),
3010 chipset(chipset) {}
3011
3012 LogicalResult
3013 matchAndRewrite(DsBarrierPollStateOp op, OpAdaptor adaptor,
3014 ConversionPatternRewriter &rewriter) const override {
3015 if (chipset < kGfx1250)
3016 return op->emitOpError("only supported on gfx1250+");
3017
3018 Location loc = op.getLoc();
3019 Type i64 = rewriter.getI64Type();
3020
3021 MemRefType memrefType = cast<MemRefType>(op.getBase().getType());
3022 Value ptr = getStridedElementPtr(rewriter, loc, memrefType,
3023 adaptor.getBase(), adaptor.getIndices());
3024
3025 // Atomic load with workgroup scope and acquire ordering should be what
3026 // we're looking for.
3027 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
3028 op, i64, ptr, /*alignment=*/8, /*volatile_=*/false,
3029 /*nontemporal=*/false, /*invariant=*/false,
3030 /*invariantGroup=*/false, LLVM::AtomicOrdering::acquire,
3031 /*syncscope=*/"workgroup");
3032 return success();
3033 }
3034};
3035
3036struct DsAsyncBarrierArriveOpLowering
3037 : public ConvertOpToLLVMPattern<DsAsyncBarrierArriveOp> {
3038 Chipset chipset;
3039
3040 DsAsyncBarrierArriveOpLowering(const LLVMTypeConverter &converter,
3041 Chipset chipset)
3042 : ConvertOpToLLVMPattern<DsAsyncBarrierArriveOp>(converter),
3043 chipset(chipset) {}
3044
3045 LogicalResult
3046 matchAndRewrite(DsAsyncBarrierArriveOp op, OpAdaptor adaptor,
3047 ConversionPatternRewriter &rewriter) const override {
3048 if (chipset < kGfx1250)
3049 return op->emitOpError("only supported on gfx1250+");
3050
3051 Location loc = op.getLoc();
3052
3053 MemRefType memrefType = cast<MemRefType>(op.getBase().getType());
3054 Value ptr = getStridedElementPtr(rewriter, loc, memrefType,
3055 adaptor.getBase(), adaptor.getIndices());
3056
3057 rewriter.replaceOpWithNewOp<ROCDL::DsAtomicAsyncBarrierArriveOp>(
3058 op, ptr, /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr,
3059 /*tbaa=*/nullptr);
3060 return success();
3061 }
3062};
3063
3064struct DsBarrierArriveOpLowering
3065 : public ConvertOpToLLVMPattern<DsBarrierArriveOp> {
3066 Chipset chipset;
3067
3068 DsBarrierArriveOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
3069 : ConvertOpToLLVMPattern<DsBarrierArriveOp>(converter), chipset(chipset) {
3070 }
3071
3072 LogicalResult
3073 matchAndRewrite(DsBarrierArriveOp op, OpAdaptor adaptor,
3074 ConversionPatternRewriter &rewriter) const override {
3075 if (chipset < kGfx1250)
3076 return op->emitOpError("only supported on gfx1250+");
3077
3078 Location loc = op.getLoc();
3079 Type i64 = rewriter.getI64Type();
3080
3081 MemRefType memrefType = cast<MemRefType>(op.getBase().getType());
3082 Value ptr = getStridedElementPtr(rewriter, loc, memrefType,
3083 adaptor.getBase(), adaptor.getIndices());
3084
3085 rewriter.replaceOpWithNewOp<ROCDL::DsAtomicBarrierArriveRtnOp>(
3086 op, i64, ptr, adaptor.getCount(), /*alias_scopes=*/nullptr,
3087 /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
3088 return success();
3089 }
3090};
3091
3092struct DsBarrierStatePhaseOpLowering
3093 : public ConvertOpToLLVMPattern<DsBarrierStatePhaseOp> {
3095
3096 LogicalResult
3097 matchAndRewrite(DsBarrierStatePhaseOp op, OpAdaptor adaptor,
3098 ConversionPatternRewriter &rewriter) const override {
3099 Location loc = op.getLoc();
3100 Type i32 = rewriter.getI32Type();
3101
3102 Value state = adaptor.getState();
3103
3104 Value noInitCount = LLVM::TruncOp::create(rewriter, loc, i32, state);
3105 Value phase = LLVM::LShrOp::create(
3106 rewriter, loc, noInitCount,
3107 createI32Constant(rewriter, loc, kDsBarrierPhasePos));
3108
3109 rewriter.replaceOp(op, phase);
3110 return success();
3111 }
3112};
3113
3114struct DsBarrierStatePendingCountOpLowering
3115 : public ConvertOpToLLVMPattern<DsBarrierStatePendingCountOp> {
3117
3118 LogicalResult
3119 matchAndRewrite(DsBarrierStatePendingCountOp op, OpAdaptor adaptor,
3120 ConversionPatternRewriter &rewriter) const override {
3121 Location loc = op.getLoc();
3122 Type i32 = rewriter.getI32Type();
3123
3124 Value state = adaptor.getState();
3125
3126 Value noInitCount = LLVM::TruncOp::create(rewriter, loc, i32, state);
3127 Value pendingCount = LLVM::AndOp::create(
3128 rewriter, loc, noInitCount,
3129 createI32Constant(rewriter, loc,
3130 static_cast<uint32_t>(kDsBarrierPendingCountMask)));
3131
3132 rewriter.replaceOp(op, pendingCount);
3133 return success();
3134 }
3135};
3136
3137struct DsBarrierStateInitCountOpLowering
3138 : public ConvertOpToLLVMPattern<DsBarrierStateInitCountOp> {
3140
3141 LogicalResult
3142 matchAndRewrite(DsBarrierStateInitCountOp op, OpAdaptor adaptor,
3143 ConversionPatternRewriter &rewriter) const override {
3144 Location loc = op.getLoc();
3145 Type i32 = rewriter.getI32Type();
3146
3147 Value state = adaptor.getState();
3148
3149 Value initCountI64 = LLVM::LShrOp::create(
3150 rewriter, loc, state,
3151 createI64Constant(rewriter, loc, kDsBarrierInitCountPos));
3152 Value initCount = LLVM::TruncOp::create(rewriter, loc, i32, initCountI64);
3153
3154 rewriter.replaceOp(op, initCount);
3155 return success();
3156 }
3157};
3158
3159struct DsBarrierStatePhaseParityLowering
3160 : public ConvertOpToLLVMPattern<DsBarrierStatePhaseParity> {
3162
3163 LogicalResult
3164 matchAndRewrite(DsBarrierStatePhaseParity op, OpAdaptor adaptor,
3165 ConversionPatternRewriter &rewriter) const override {
3166 Location loc = op.getLoc();
3167 Type i1 = rewriter.getI1Type();
3168
3169 Value state = adaptor.getState();
3170
3171 Value noInitCount =
3172 LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), state);
3173 Value phase = LLVM::LShrOp::create(
3174 rewriter, loc, noInitCount,
3175 createI32Constant(rewriter, loc, kDsBarrierPhasePos));
3176 Value parity = LLVM::TruncOp::create(rewriter, loc, i1, phase);
3177
3178 rewriter.replaceOp(op, parity);
3179 return success();
3180 }
3181};
3182
3183//===----------------------------------------------------------------------===//
3184// Tensor Data Mover (TDM)
3185//===----------------------------------------------------------------------===//
3186
3187static Value setValueAtOffset(ConversionPatternRewriter &rewriter, Location loc,
3188 Value accumulator, Value value, int64_t shift) {
3189 shift = shift % 32;
3190 Value shiftAmount;
3191 if (shift != 0) {
3192 shiftAmount = createI32Constant(rewriter, loc, shift % 32);
3193 value = LLVM::ShlOp::create(rewriter, loc, value, shiftAmount);
3194 }
3195
3196 if (matchPattern(accumulator, mlir::m_Zero()))
3197 return value;
3198
3199 constexpr bool isDisjoint = true;
3200 return LLVM::OrOp::create(rewriter, loc, accumulator, value, isDisjoint);
3201}
3202
3203template <typename BaseOp>
3204struct AMDGPUMakeDmaBaseLowering : public ConvertOpToLLVMPattern<BaseOp> {
3205 using ConvertOpToLLVMPattern<BaseOp>::ConvertOpToLLVMPattern;
3206 using Adaptor = typename ConvertOpToLLVMPattern<BaseOp>::OpAdaptor;
3207
3208 AMDGPUMakeDmaBaseLowering(const LLVMTypeConverter &converter, Chipset chipset)
3209 : ConvertOpToLLVMPattern<BaseOp>(converter), chipset(chipset) {}
3210 Chipset chipset;
3211
3212 LogicalResult
3213 matchAndRewrite(BaseOp op, Adaptor adaptor,
3214 ConversionPatternRewriter &rewriter) const override {
3215 if (chipset < kGfx1250)
3216 return op->emitOpError("make_dma_base is only supported on gfx1250");
3217
3218 Location loc = op.getLoc();
3219
3220 constexpr int32_t constlen = 4;
3221 Value consts[constlen];
3222 for (int64_t i = 0; i < constlen; ++i)
3223 consts[i] = createI32Constant(rewriter, loc, i);
3224
3225 constexpr int32_t sgprslen = constlen;
3226 Value sgprs[sgprslen];
3227 for (int64_t i = 0; i < sgprslen; ++i) {
3228 sgprs[i] = consts[0];
3229 }
3230
3231 sgprs[0] = consts[1];
3232
3233 if constexpr (BaseOp::isGather()) {
3234 sgprs[0] = setValueAtOffset(rewriter, loc, sgprs[0], consts[1], 30);
3235
3236 auto type = cast<TDMGatherBaseType>(op.getResult().getType());
3237 Type indexType = type.getIndexType();
3238 unsigned indexSize = indexType.getIntOrFloatBitWidth();
3239 assert(llvm::is_contained({16u, 32u}, indexSize) &&
3240 "expected index_size to be 16 or 32");
3241 unsigned idx = (indexSize / 16) - 1;
3242
3243 if (idx)
3244 sgprs[0] = setValueAtOffset(rewriter, loc, sgprs[0], consts[1], 31);
3245 }
3246
3247 ValueRange ldsIndices = adaptor.getLdsIndices();
3248 Value lds = adaptor.getLds();
3249 auto ldsMemRefType = cast<MemRefType>(op.getLds().getType());
3250
3252 rewriter, loc, ldsMemRefType, lds, ldsIndices);
3253
3254 ValueRange globalIndices = adaptor.getGlobalIndices();
3255 Value global = adaptor.getGlobal();
3256 auto globalMemRefType = cast<MemRefType>(op.getGlobal().getType());
3257
3259 rewriter, loc, globalMemRefType, global, globalIndices);
3260
3261 Type i32 = rewriter.getI32Type();
3262 Type i64 = rewriter.getI64Type();
3263
3264 sgprs[1] = LLVM::PtrToIntOp::create(rewriter, loc, i32, ldsPtr);
3265 Value castForGlobalAddr =
3266 LLVM::PtrToIntOp::create(rewriter, loc, i64, globalPtr);
3267
3268 sgprs[2] = LLVM::TruncOp::create(rewriter, loc, i32, castForGlobalAddr);
3269
3270 Value shift = LLVM::LShrOp::create(rewriter, loc, castForGlobalAddr,
3271 createI64Constant(rewriter, loc, 32));
3272
3273 Value highHalf = LLVM::TruncOp::create(rewriter, loc, i32, shift);
3274
3275 Value mask = createI32Constant(rewriter, loc, (1ull << 25) - 1);
3276 highHalf = LLVM::AndOp::create(rewriter, loc, highHalf, mask);
3277
3278 sgprs[3] = setValueAtOffset(rewriter, loc, highHalf, consts[2], 30);
3279
3280 Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
3281 assert(v4i32 && "expected type conversion to succeed");
3282 Value result = LLVM::PoisonOp::create(rewriter, loc, v4i32);
3283
3284 for (auto [sgpr, constant] : llvm::zip_equal(sgprs, consts))
3285 result =
3286 LLVM::InsertElementOp::create(rewriter, loc, result, sgpr, constant);
3287
3288 rewriter.replaceOp(op, result);
3289 return success();
3290 }
3291};
3292
3293template <typename DescriptorOp>
3294struct AMDGPULowerDescriptor : public ConvertOpToLLVMPattern<DescriptorOp> {
3295 using ConvertOpToLLVMPattern<DescriptorOp>::ConvertOpToLLVMPattern;
3296 using OpAdaptor = typename ConvertOpToLLVMPattern<DescriptorOp>::OpAdaptor;
3297
  /// Construct the pattern, recording the target chipset so the lowering can
  /// be gated on hardware support.
  AMDGPULowerDescriptor(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<DescriptorOp>(converter), chipset(chipset) {}
  Chipset chipset;
3301
3302 Value getDGroup0(OpAdaptor adaptor) const { return adaptor.getBase(); }
3303
3304 Value setWorkgroupMask(DescriptorOp op, OpAdaptor adaptor,
3305 ConversionPatternRewriter &rewriter, Location loc,
3306 Value sgpr0) const {
3307 Value mask = op.getWorkgroupMask();
3308 if (!mask)
3309 return sgpr0;
3310
3311 Type i16 = rewriter.getI16Type();
3312 mask = LLVM::BitcastOp::create(rewriter, loc, i16, mask);
3313 Type i32 = rewriter.getI32Type();
3314 Value extendedMask = LLVM::ZExtOp::create(rewriter, loc, i32, mask);
3315 return setValueAtOffset(rewriter, loc, sgpr0, extendedMask, 0);
3316 }
3317
3318 Value setDataSize(DescriptorOp op, OpAdaptor adaptor,
3319 ConversionPatternRewriter &rewriter, Location loc,
3320 Value sgpr0, ArrayRef<Value> consts) const {
3321 unsigned elementTypeWidthInBits = op.getElementTypeWidth();
3322 assert(llvm::is_contained({8u, 16u, 32u, 64u}, elementTypeWidthInBits) &&
3323 "expected type width to be 8, 16, 32, or 64.");
3324 int64_t idx = llvm::Log2_32(elementTypeWidthInBits / 8);
3325 Value size = consts[idx];
3326 return setValueAtOffset(rewriter, loc, sgpr0, size, 16);
3327 }
3328
3329 Value setAtomicBarrier(DescriptorOp op, OpAdaptor adaptor,
3330 ConversionPatternRewriter &rewriter, Location loc,
3331 Value sgpr0, ArrayRef<Value> consts) const {
3332 if (!adaptor.getAtomicBarrierAddress())
3333 return sgpr0;
3334
3335 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 18);
3336 }
3337
3338 Value setIterateEnable(DescriptorOp op, OpAdaptor adaptor,
3339 ConversionPatternRewriter &rewriter, Location loc,
3340 Value sgpr0, ArrayRef<Value> consts) const {
3341 if (!adaptor.getGlobalIncrement())
3342 return sgpr0;
3343
3344 // Value is ignored when in gather mode.
3345 // TODO: emit error earlier?
3346 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 19);
3347 }
3348
3349 Value setPadEnable(DescriptorOp op, OpAdaptor adaptor,
3350 ConversionPatternRewriter &rewriter, Location loc,
3351 Value sgpr0, ArrayRef<Value> consts) const {
3352 if (!op.getPadAmount())
3353 return sgpr0;
3354
3355 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 20);
3356 }
3357
3358 Value setEarlyTimeout(DescriptorOp op, OpAdaptor adaptor,
3359 ConversionPatternRewriter &rewriter, Location loc,
3360 Value sgpr0, ArrayRef<Value> consts) const {
3361 if (!op.getWorkgroupMask())
3362 return sgpr0;
3363
3364 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 21);
3365 }
3366
3367 Value setPadInterval(DescriptorOp op, OpAdaptor adaptor,
3368 ConversionPatternRewriter &rewriter, Location loc,
3369 Value sgpr0, ArrayRef<Value> consts) const {
3370 if (!op.getPadAmount())
3371 return sgpr0;
3372
3373 // pre-condition: padInterval can be a power of two between 2 and 256.
3374 // TODO: Validation if the value breaks the pre-condition.
3375 // If the pre-condition fails, there is a possibility of
3376 // affecting the higher bits. In a following PR implement
3377 // RuntimeVerifiableOpInterface that instruments conditions that need to be
3378 // checked at runtime.
3379 IntegerType i32 = rewriter.getI32Type();
3380 Value padInterval = adaptor.getPadInterval();
3381 padInterval = LLVM::CountTrailingZerosOp::create(rewriter, loc, i32,
3382 padInterval, false);
3383 padInterval = LLVM::SubOp::create(rewriter, loc, padInterval, consts[1]);
3384 // post-condition: padInterval can be a value between 0 and 7.
3385 return setValueAtOffset(rewriter, loc, sgpr0, padInterval, 22);
3386 }
3387
3388 Value setPadAmount(DescriptorOp op, OpAdaptor adaptor,
3389 ConversionPatternRewriter &rewriter, Location loc,
3390 Value sgpr0, ArrayRef<Value> consts) const {
3391 if (!op.getPadAmount())
3392 return sgpr0;
3393
3394 // pre-condition: padAmount is a value between 1-128.
3395 // TODO: Validation if the value breaks the pre-condition.
3396 // If the pre-condition fails, there is a possibility of
3397 // affecting the higher bits. In a following PR implement
3398 // RuntimeVerifiableOpInterface that instruments conditions that need to be
3399 // checked at runtime.
3400 Value padAmount = adaptor.getPadAmount();
3401 padAmount = LLVM::SubOp::create(rewriter, loc, padAmount, consts[1]);
3402 // post-condition: padAmount is a value between 0-127.
3403 return setValueAtOffset(rewriter, loc, sgpr0, padAmount, 25);
3404 }
3405
  /// Pack the atomic barrier element address into the low 16 bits of `sgpr1`
  /// (offset 32 maps to bit 0 of this word, since shifts are taken mod 32).
  /// The memref element address is materialized, converted to i32, shifted
  /// right by consts[3] — presumably the constant 3, dropping the zero
  /// alignment bits; confirm against the caller's `consts` setup — and
  /// masked to 16 bits. No-op when the op carries no barrier address.
  Value setAtomicBarrierAddress(DescriptorOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter,
                                Location loc, Value sgpr1,
                                ArrayRef<Value> consts) const {
    if (!adaptor.getAtomicBarrierAddress())
      return sgpr1;

    Value atomicBarrierAddress = adaptor.getAtomicBarrierAddress();
    auto barrierAddressTy =
        cast<MemRefType>(op.getAtomicBarrierAddress().getType());
    ValueRange atomicBarrierIndices = adaptor.getAtomicBarrierIndices();
    // Resolve the indexed element to an LLVM pointer.
    atomicBarrierAddress = ConvertToLLVMPattern::getStridedElementPtr(
        rewriter, loc, barrierAddressTy, atomicBarrierAddress,
        atomicBarrierIndices);
    IntegerType i32 = rewriter.getI32Type();
    // pre-condition: atomicBarrierAddress is aligned to 8 bytes which implies
    // that the 3 LSBs are zero.
    // TODO: Validation if the value breaks the pre-condition.
    // In a following PR implement RuntimeVerifiableOpInterface
    // that instruments conditions that need to be checked at runtime.
    atomicBarrierAddress =
        LLVM::PtrToIntOp::create(rewriter, loc, i32, atomicBarrierAddress);
    atomicBarrierAddress =
        LLVM::LShrOp::create(rewriter, loc, atomicBarrierAddress, consts[3]);
    // Keep only 16 bits of the shifted address.
    Value mask = createI32Constant(rewriter, loc, 0xFFFF);
    atomicBarrierAddress =
        LLVM::AndOp::create(rewriter, loc, atomicBarrierAddress, mask);
    return setValueAtOffset(rewriter, loc, sgpr1, atomicBarrierAddress, 32);
  }
3435
  /// Write global tensor size `dimX` (counted from the innermost dimension)
  /// into the descriptor, split across two 32-bit words: the low bits go into
  /// `sgpr1` at `offset`, the remaining high bits into `sgpr2` at
  /// `offset + 16`. Returns the pair unchanged when the tensor has no
  /// dimension `dimX`.
  std::pair<Value, Value> setTensorDimX(DescriptorOp op, OpAdaptor adaptor,
                                        ConversionPatternRewriter &rewriter,
                                        Location loc, Value sgpr1, Value sgpr2,
                                        ArrayRef<Value> consts, uint64_t dimX,
                                        uint32_t offset) const {
    ArrayRef<int64_t> globalStaticSizes = adaptor.getGlobalStaticSizes();
    ValueRange globalDynamicSizes = adaptor.getGlobalDynamicSizes();
    SmallVector<OpFoldResult> mixedGlobalSizes =
        getMixedValues(globalStaticSizes, globalDynamicSizes, rewriter);
    if (mixedGlobalSizes.size() <= dimX)
      return {sgpr1, sgpr2};

    // rbegin(): dimX counts from the innermost (last) dimension.
    OpFoldResult tensorDimXOpFoldResult = *(mixedGlobalSizes.rbegin() + dimX);
    // pre-condition: tensorDimX is less than 2^32-1
    // TODO: Validation if the value breaks the pre-condition.
    // In a following PR implement RuntimeVerifiableOpInterface that instruments
    // conditions that need to be checked at runtime. This could also be fixed
    // by saying that mixedGlobalSizes is a DynamicI32List.
    Value tensorDimX;
    if (auto attr = dyn_cast<Attribute>(tensorDimXOpFoldResult)) {
      // Static size: materialize it as an i32 constant.
      tensorDimX =
          createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
    } else {
      // Dynamic size: truncate the SSA value down to i32.
      IntegerType i32 = rewriter.getI32Type();
      tensorDimX = cast<Value>(tensorDimXOpFoldResult);
      tensorDimX = LLVM::TruncOp::create(rewriter, loc, i32, tensorDimX);
    }

    sgpr1 = setValueAtOffset(rewriter, loc, sgpr1, tensorDimX, offset);

    // The upper 16 bits spill over into the next word at offset + 16.
    Value c16 = createI32Constant(rewriter, loc, 16);
    Value tensorDimXHigh = LLVM::LShrOp::create(rewriter, loc, tensorDimX, c16);
    sgpr2 = setValueAtOffset(rewriter, loc, sgpr2, tensorDimXHigh, offset + 16);
    return {sgpr1, sgpr2};
  }
3471
3472 std::pair<Value, Value> setTensorDim0(DescriptorOp op, OpAdaptor adaptor,
3473 ConversionPatternRewriter &rewriter,
3474 Location loc, Value sgpr1, Value sgpr2,
3475 ArrayRef<Value> consts) const {
3476 return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, sgpr2, consts, 0,
3477 48);
3478 }
3479
3480 std::pair<Value, Value> setTensorDim1(DescriptorOp op, OpAdaptor adaptor,
3481 ConversionPatternRewriter &rewriter,
3482 Location loc, Value sgpr2, Value sgpr3,
3483 ArrayRef<Value> consts) const {
3484 return setTensorDimX(op, adaptor, rewriter, loc, sgpr2, sgpr3, consts, 1,
3485 80);
3486 }
3487
  /// Write shared (tile) size `dimX` (counted from the innermost dimension)
  /// into `sgpr` at bit `offset`. Returns `sgpr` unchanged when the tile has
  /// no dimension `dimX`.
  Value setTileDimX(DescriptorOp op, OpAdaptor adaptor,
                    ConversionPatternRewriter &rewriter, Location loc,
                    Value sgpr, ArrayRef<Value> consts, size_t dimX,
                    int64_t offset) const {
    ArrayRef<int64_t> sharedStaticSizes = adaptor.getSharedStaticSizes();
    ValueRange sharedDynamicSizes = adaptor.getSharedDynamicSizes();
    SmallVector<OpFoldResult> mixedSharedSizes =
        getMixedValues(sharedStaticSizes, sharedDynamicSizes, rewriter);
    if (mixedSharedSizes.size() <= dimX)
      return sgpr;

    // rbegin(): dimX counts from the innermost (last) dimension.
    OpFoldResult tileDimXOpFoldResult = *(mixedSharedSizes.rbegin() + dimX);
    // pre-condition: tileDimX is less than 2^16-1
    // TODO: Validation if the value breaks the pre-condition.
    // If the pre-condition fails, there is a possibility of
    // affecting the higher bits. In a following PR implement
    // RuntimeVerifiableOpInterface that instruments conditions that need to be
    // checked at runtime. This could also be fixed by saying that
    // mixedSharedSizes is a DynamicI16List.
    Value tileDimX;
    if (auto attr = dyn_cast<Attribute>(tileDimXOpFoldResult)) {
      // Static size: materialize it as an i32 constant.
      tileDimX =
          createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
    } else {
      // Dynamic size: truncate the SSA value down to i32.
      IntegerType i32 = rewriter.getI32Type();
      tileDimX = cast<Value>(tileDimXOpFoldResult);
      tileDimX = LLVM::TruncOp::create(rewriter, loc, i32, tileDimX);
    }

    return setValueAtOffset(rewriter, loc, sgpr, tileDimX, offset);
  }
3519
3520 Value setTileDim0(DescriptorOp op, OpAdaptor adaptor,
3521 ConversionPatternRewriter &rewriter, Location loc,
3522 Value sgpr3, ArrayRef<Value> consts) const {
3523 return setTileDimX(op, adaptor, rewriter, loc, sgpr3, consts, 0, 112);
3524 }
3525
3526 Value setTileDim1(DescriptorOp op, OpAdaptor adaptor,
3527 ConversionPatternRewriter &rewriter, Location loc,
3528 Value sgpr4, ArrayRef<Value> consts) const {
3529 return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 1, 128);
3530 }
3531
3532 Value setValidIndices(DescriptorOp op, OpAdaptor adaptor,
3533 ConversionPatternRewriter &rewriter, Location loc,
3534 Value sgpr4, ArrayRef<Value> consts) const {
3535 auto type = cast<VectorType>(op.getIndices().getType());
3536 ArrayRef<int64_t> shape = type.getShape();
3537 assert(shape.size() == 1 && "expected shape to be of rank 1.");
3538 unsigned length = shape.back();
3539 assert(0 < length && length <= 16 && "expected length to be at most 16.");
3540 Value value = createI32Constant(rewriter, loc, length);
3541 return setValueAtOffset(rewriter, loc, sgpr4, value, 128);
3542 }
3543
3544 Value setTileDim1OrValidIndices(DescriptorOp op, OpAdaptor adaptor,
3545 ConversionPatternRewriter &rewriter,
3546 Location loc, Value sgpr4,
3547 ArrayRef<Value> consts) const {
3548 if constexpr (DescriptorOp::isGather())
3549 return setValidIndices(op, adaptor, rewriter, loc, sgpr4, consts);
3550 return setTileDim1(op, adaptor, rewriter, loc, sgpr4, consts);
3551 }
3552
3553 Value setTileDim2(DescriptorOp op, OpAdaptor adaptor,
3554 ConversionPatternRewriter &rewriter, Location loc,
3555 Value sgpr4, ArrayRef<Value> consts) const {
3556 // Value is ignored when in gather mode.
3557 if constexpr (DescriptorOp::isGather())
3558 return sgpr4;
3559 return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 2, 144);
3560 }
3561
  /// Write the 48-bit stride of tensor dimension `dimX` into two descriptor
  /// words: the low 32 bits into `sgprY` at `offset`, the remaining bits into
  /// `sgprZ` just past the word boundary. Strides are read at `dimX + 1`
  /// from the innermost end (the innermost stride itself is not encoded).
  /// Returns the pair unchanged when no such stride exists.
  std::pair<Value, Value>
  setTensorDimXStride(DescriptorOp op, OpAdaptor adaptor,
                      ConversionPatternRewriter &rewriter, Location loc,
                      Value sgprY, Value sgprZ, ArrayRef<Value> consts,
                      size_t dimX, int64_t offset) const {
    ArrayRef<int64_t> globalStaticStrides = adaptor.getGlobalStaticStrides();
    ValueRange globalDynamicStrides = adaptor.getGlobalDynamicStrides();
    SmallVector<OpFoldResult> mixedGlobalStrides =
        getMixedValues(globalStaticStrides, globalDynamicStrides, rewriter);

    if (mixedGlobalStrides.size() <= (dimX + 1))
      return {sgprY, sgprZ};

    OpFoldResult tensorDimXStrideOpFoldResult =
        *(mixedGlobalStrides.rbegin() + dimX + 1);
    // pre-condition: tensorDimXStride is less than 2^48-1
    // TODO: Validation if the value breaks the pre-condition.
    // In a following PR implement RuntimeVerifiableOpInterface that instruments
    // conditions that need to be checked at runtime.
    Value tensorDimXStride;
    if (auto attr = dyn_cast<Attribute>(tensorDimXStrideOpFoldResult))
      // Static stride: materialize as an i64 constant.
      tensorDimXStride =
          createI64Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
    else
      tensorDimXStride = cast<Value>(tensorDimXStrideOpFoldResult);

    // Clamp to the 48 bits the descriptor field can hold.
    constexpr int64_t first48bits = (1ll << 48) - 1;
    Value mask = createI64Constant(rewriter, loc, first48bits);
    tensorDimXStride =
        LLVM::AndOp::create(rewriter, loc, mask, tensorDimXStride);
    IntegerType i32 = rewriter.getI32Type();
    Value tensorDimXStrideLow =
        LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStride);
    sgprY = setValueAtOffset(rewriter, loc, sgprY, tensorDimXStrideLow, offset);

    // Shift so the remaining bits land at the start of the next word; a
    // word-aligned offset means the high part begins a full 32 bits up.
    int64_t shift = (offset % 32) == 0 ? 32 : offset % 32;
    Value shiftVal = createI64Constant(rewriter, loc, shift);
    Value tensorDimXStrideHigh =
        LLVM::LShrOp::create(rewriter, loc, tensorDimXStride, shiftVal);
    tensorDimXStrideHigh =
        LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStrideHigh);
    sgprZ = setValueAtOffset(rewriter, loc, sgprZ, tensorDimXStrideHigh,
                             offset + shift);
    return {sgprY, sgprZ};
  }
3607
3608 std::pair<Value, Value>
3609 setTensorDim0Stride(DescriptorOp op, OpAdaptor adaptor,
3610 ConversionPatternRewriter &rewriter, Location loc,
3611 Value sgpr5, Value sgpr6, ArrayRef<Value> consts) const {
3612 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts,
3613 0, 160);
3614 }
3615
3616 std::pair<Value, Value>
3617 setTensorDim1Stride(DescriptorOp op, OpAdaptor adaptor,
3618 ConversionPatternRewriter &rewriter, Location loc,
3619 Value sgpr5, Value sgpr6, ArrayRef<Value> consts) const {
3620 // Value is ignored when in gather mode.
3621 if constexpr (DescriptorOp::isGather())
3622 return {sgpr5, sgpr6};
3623 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts,
3624 1, 208);
3625 }
3626
  /// Assemble descriptor group 1 as a v8i32: eight accumulator words are
  /// seeded with zero (consts[0]) and every field setter ORs its value into
  /// the appropriate word(s), then the words are inserted into the vector.
  /// Assumes `consts` holds exactly eight i32 constants 0..7 — zip_equal
  /// below requires equal lengths; confirm against the caller.
  Value getDGroup1(DescriptorOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter, Location loc,
                   ArrayRef<Value> consts) const {
    Value sgprs[8];
    for (int64_t i = 0; i < 8; ++i) {
      sgprs[i] = consts[0];
    }

    // Word 0: flag and small scalar fields.
    sgprs[0] = setWorkgroupMask(op, adaptor, rewriter, loc, sgprs[0]);
    sgprs[0] = setDataSize(op, adaptor, rewriter, loc, sgprs[0], consts);
    sgprs[0] = setAtomicBarrier(op, adaptor, rewriter, loc, sgprs[0], consts);
    sgprs[0] = setIterateEnable(op, adaptor, rewriter, loc, sgprs[0], consts);
    sgprs[0] = setPadEnable(op, adaptor, rewriter, loc, sgprs[0], consts);
    sgprs[0] = setEarlyTimeout(op, adaptor, rewriter, loc, sgprs[0], consts);
    sgprs[0] = setPadInterval(op, adaptor, rewriter, loc, sgprs[0], consts);
    sgprs[0] = setPadAmount(op, adaptor, rewriter, loc, sgprs[0], consts);

    // Words 1-3: barrier address, tensor sizes (each dim straddles a pair of
    // words, so the setters return two updated words).
    sgprs[1] =
        setAtomicBarrierAddress(op, adaptor, rewriter, loc, sgprs[1], consts);
    std::tie(sgprs[1], sgprs[2]) =
        setTensorDim0(op, adaptor, rewriter, loc, sgprs[1], sgprs[2], consts);
    std::tie(sgprs[2], sgprs[3]) =
        setTensorDim1(op, adaptor, rewriter, loc, sgprs[2], sgprs[3], consts);

    // Words 3-7: tile sizes and strides.
    sgprs[3] = setTileDim0(op, adaptor, rewriter, loc, sgprs[3], consts);
    sgprs[4] =
        setTileDim1OrValidIndices(op, adaptor, rewriter, loc, sgprs[4], consts);
    sgprs[4] = setTileDim2(op, adaptor, rewriter, loc, sgprs[4], consts);
    std::tie(sgprs[5], sgprs[6]) = setTensorDim0Stride(
        op, adaptor, rewriter, loc, sgprs[5], sgprs[6], consts);
    std::tie(sgprs[6], sgprs[7]) = setTensorDim1Stride(
        op, adaptor, rewriter, loc, sgprs[6], sgprs[7], consts);

    IntegerType i32 = rewriter.getI32Type();
    Type v8i32 = this->typeConverter->convertType(VectorType::get(8, i32));
    assert(v8i32 && "expected type conversion to succeed");
    Value dgroup1 = LLVM::PoisonOp::create(rewriter, loc, v8i32);

    // consts double as the insertion indices 0..7.
    for (auto [sgpr, constant] : llvm::zip_equal(sgprs, consts)) {
      dgroup1 =
          LLVM::InsertElementOp::create(rewriter, loc, dgroup1, sgpr, constant);
    }

    return dgroup1;
  }
3672
  /// Single-word overload of setTensorDimX: writes global tensor size `dimX`
  /// (from the innermost dimension) into `sgpr0` at `offset` without a
  /// spill-over high half. Returns `sgpr0` unchanged when the tensor has no
  /// dimension `dimX`.
  Value setTensorDimX(DescriptorOp op, OpAdaptor adaptor,
                      ConversionPatternRewriter &rewriter, Location loc,
                      Value sgpr0, ArrayRef<Value> consts, int64_t dimX,
                      int64_t offset) const {
    ArrayRef<int64_t> globalStaticSizes = adaptor.getGlobalStaticSizes();
    ValueRange globalDynamicSizes = adaptor.getGlobalDynamicSizes();
    SmallVector<OpFoldResult> mixedGlobalSizes =
        getMixedValues(globalStaticSizes, globalDynamicSizes, rewriter);
    if (mixedGlobalSizes.size() <= static_cast<unsigned long>(dimX))
      return sgpr0;

    // rbegin(): dimX counts from the innermost (last) dimension.
    OpFoldResult tensorDimXOpFoldResult = *(mixedGlobalSizes.rbegin() + dimX);
    Value tensorDimX;
    if (auto attr = dyn_cast<Attribute>(tensorDimXOpFoldResult)) {
      // Static size: materialize as an i32 constant.
      tensorDimX =
          createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
    } else {
      // Dynamic size: truncate the SSA value down to i32.
      IntegerType i32 = rewriter.getI32Type();
      tensorDimX = cast<Value>(tensorDimXOpFoldResult);
      tensorDimX = LLVM::TruncOp::create(rewriter, loc, i32, tensorDimX);
    }

    return setValueAtOffset(rewriter, loc, sgpr0, tensorDimX, offset);
  }
3697
3698 Value setTensorDim2(DescriptorOp op, OpAdaptor adaptor,
3699 ConversionPatternRewriter &rewriter, Location loc,
3700 Value sgpr0, ArrayRef<Value> consts) const {
3701 return setTensorDimX(op, adaptor, rewriter, loc, sgpr0, consts, 2, 0);
3702 }
3703
3704 Value truncateAndSetValueAtOffset(ConversionPatternRewriter &rewriter,
3705 Location loc, Value accumulator,
3706 Value value, int64_t shift) const {
3707
3708 IntegerType i32 = rewriter.getI32Type();
3709 value = LLVM::TruncOp::create(rewriter, loc, i32, value);
3710 return setValueAtOffset(rewriter, loc, accumulator, value, shift);
3711 }
3712
3713 Value setLDSAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
3714 ConversionPatternRewriter &rewriter, Location loc,
3715 Value sgpr1, ArrayRef<Value> consts,
3716 int64_t offset) const {
3717 Value ldsAddrIncrement = adaptor.getLdsIncrement();
3718 return setValueAtOffset(rewriter, loc, sgpr1, ldsAddrIncrement, offset);
3719 }
3720
  /// Split the 64-bit global address increment across `sgpr2` (low 32 bits
  /// at `offset`) and `sgpr3` (next 16 bits at `offset + 32`).
  std::pair<Value, Value>
  setGlobalAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
                         ConversionPatternRewriter &rewriter, Location loc,
                         Value sgpr2, Value sgpr3, ArrayRef<Value> consts,
                         int64_t offset) const {
    Value globalAddrIncrement = adaptor.getGlobalIncrement();
    sgpr2 = truncateAndSetValueAtOffset(rewriter, loc, sgpr2,
                                        globalAddrIncrement, offset);
    Value shift = createI64Constant(rewriter, loc, 32);
    globalAddrIncrement =
        LLVM::LShrOp::create(rewriter, loc, globalAddrIncrement, shift);
    constexpr int64_t first16BitsHigh = (1ll << 16) - 1;
    sgpr3 = truncateAndSetValueAtOffset(rewriter, loc, sgpr3,
                                        globalAddrIncrement, offset + 32);
    // NOTE(review): the 16-bit mask is applied to the whole accumulated
    // sgpr3, not only to the newly inserted bits — any bits above 15 already
    // in sgpr3 are cleared here. The caller fills the remaining sgpr3 fields
    // only afterwards (see getDGroup2NonGather), but confirm this ordering
    // is a deliberate invariant rather than an accident.
    Value mask = createI32Constant(rewriter, loc, first16BitsHigh);
    sgpr3 = LLVM::AndOp::create(rewriter, loc, sgpr3, mask);
    return {sgpr2, sgpr3};
  }
3739
3740 Value setTensorDim3OrLDSAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
3741 ConversionPatternRewriter &rewriter,
3742 Location loc, Value sgpr1,
3743 ArrayRef<Value> consts) const {
3744 Value ldsIncrement = op.getLdsIncrement();
3745 constexpr int64_t dim = 3;
3746 constexpr int64_t offset = 32;
3747 if (!ldsIncrement)
3748 return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, consts, dim,
3749 offset);
3750 return setLDSAddrIncrement(op, adaptor, rewriter, loc, sgpr1, consts,
3751 offset);
3752 }
3753
3754 std::pair<Value, Value> setTensorDim2StrideOrGlobalAddrIncrement(
3755 DescriptorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter,
3756 Location loc, Value sgpr2, Value sgpr3, ArrayRef<Value> consts) const {
3757 Value globalIncrement = op.getGlobalIncrement();
3758 constexpr int32_t dim = 2;
3759 constexpr int32_t offset = 64;
3760 if (!globalIncrement)
3761 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr2, sgpr3,
3762 consts, dim, offset);
3763 return setGlobalAddrIncrement(op, adaptor, rewriter, loc, sgpr2, sgpr3,
3764 consts, offset);
3765 }
3766
3767 Value setIterateCount(DescriptorOp op, OpAdaptor adaptor,
3768 ConversionPatternRewriter &rewriter, Location loc,
3769 Value sgpr3, ArrayRef<Value> consts,
3770 int32_t offset) const {
3771 Value iterationCount = adaptor.getIterationCount();
3772 IntegerType i32 = rewriter.getI32Type();
3773 // pre-condition: iterationCount is in the inclusive interval [1, 256].
3774 // TODO: validation if the value breaks the pre-condition.
3775 // If the pre-condition fails, there is a possibility of
3776 // affecting the higher bits. In a following PR implement
3777 // RuntimeVerifiableOpInterface that instruments conditions that need to be
3778 // checked at runtime.
3779 iterationCount = LLVM::TruncOp::create(rewriter, loc, i32, iterationCount);
3780 iterationCount =
3781 LLVM::SubOp::create(rewriter, loc, iterationCount, consts[1]);
3782 return setValueAtOffset(rewriter, loc, sgpr3, iterationCount, offset);
3783 }
3784
3785 Value setTileDim3OrIterateCount(DescriptorOp op, OpAdaptor adaptor,
3786 ConversionPatternRewriter &rewriter,
3787 Location loc, Value sgpr3,
3788 ArrayRef<Value> consts) const {
3789 Value iterateCount = op.getIterationCount();
3790 constexpr int32_t dim = 2;
3791 constexpr int32_t offset = 112;
3792 if (!iterateCount)
3793 return setTileDimX(op, adaptor, rewriter, loc, sgpr3, consts, dim,
3794 offset);
3795
3796 return setIterateCount(op, adaptor, rewriter, loc, sgpr3, consts, offset);
3797 }
3798
3799 Value getDGroup2(DescriptorOp op, OpAdaptor adaptor,
3800 ConversionPatternRewriter &rewriter, Location loc,
3801 ArrayRef<Value> consts) const {
3802 if constexpr (DescriptorOp::isGather())
3803 return getDGroup2Gather(op, adaptor, rewriter, loc, consts);
3804 return getDGroup2NonGather(op, adaptor, rewriter, loc, consts);
3805 }
3806
  /// Assemble non-gather descriptor group 2 as a v4i32. Emits an all-zero
  /// vector when the transfer fits in two descriptor groups (no LDS
  /// increment and rank <= 2); otherwise packs tensor dim 2, the overloaded
  /// dim-3 / LDS-increment field, the overloaded dim-2-stride /
  /// global-increment field, and the overloaded tile-dim / iterate-count
  /// field.
  Value getDGroup2NonGather(DescriptorOp op, OpAdaptor adaptor,
                            ConversionPatternRewriter &rewriter, Location loc,
                            ArrayRef<Value> consts) const {
    IntegerType i32 = rewriter.getI32Type();
    Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
    assert(v4i32 && "expected type conversion to succeed.");

    bool onlyNeedsTwoDescriptors = !op.getLdsIncrement() && op.getRank() <= 2;
    if (onlyNeedsTwoDescriptors)
      return LLVM::ZeroOp::create(rewriter, loc, v4i32);

    // Seed all four words with zero (consts[0]).
    constexpr int64_t sgprlen = 4;
    Value sgprs[sgprlen];
    for (int i = 0; i < sgprlen; ++i)
      sgprs[i] = consts[0];

    sgprs[0] = setTensorDim2(op, adaptor, rewriter, loc, sgprs[0], consts);
    sgprs[1] = setTensorDim3OrLDSAddrIncrement(op, adaptor, rewriter, loc,
                                               sgprs[1], consts);
    std::tie(sgprs[2], sgprs[3]) = setTensorDim2StrideOrGlobalAddrIncrement(
        op, adaptor, rewriter, loc, sgprs[2], sgprs[3], consts);
    sgprs[3] =
        setTileDim3OrIterateCount(op, adaptor, rewriter, loc, sgprs[3], consts);

    // zip (not zip_equal): iterates the shorter of the two ranges, i.e. the
    // four sgprs; consts supplies the insertion indices 0..3.
    Value dgroup2 = LLVM::PoisonOp::create(rewriter, loc, v4i32);
    for (auto [sgpr, constant] : llvm::zip(sgprs, consts))
      dgroup2 =
          LLVM::InsertElementOp::create(rewriter, loc, dgroup2, sgpr, constant);

    return dgroup2;
  }
3838
  /// Pack one half of the gather index vector into a v4i32 descriptor group.
  /// A group holds 4 indices when they are i32, or 8 when they are narrower
  /// (pairs of 16-bit indices are packed into each i32 lane). `firstHalf`
  /// selects which window of the source vector is emitted; lanes beyond the
  /// available indices are left as poison.
  Value getGatherIndices(DescriptorOp op, OpAdaptor adaptor,
                         ConversionPatternRewriter &rewriter, Location loc,
                         ArrayRef<Value> consts, bool firstHalf) const {
    IntegerType i32 = rewriter.getI32Type();
    Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
    assert(v4i32 && "expected type conversion to succeed.");

    Value indices = adaptor.getIndices();
    auto vectorType = cast<VectorType>(indices.getType());
    unsigned length = vectorType.getShape().back();
    Type elementType = vectorType.getElementType();
    // 4 i32 indices per group, or 8 narrower ones (packed two per lane).
    unsigned maxLength = elementType == i32 ? 4 : 8;
    int32_t offset = firstHalf ? 0 : maxLength;
    // Number of indices left after skipping the first window; the signed
    // cast clamps the second-half case where offset exceeds length.
    unsigned discountedLength =
        std::max(static_cast<int32_t>(length - offset), 0);

    unsigned targetSize = std::min(maxLength, discountedLength);

    // Extract the window [offset, offset + targetSize) from the vector.
    SmallVector<Value> indicesVector;
    for (unsigned i = offset; i < targetSize + offset; ++i) {
      Value idx;
      // Reuse a prebuilt constant when available, otherwise materialize one.
      if (i < consts.size())
        idx = consts[i];
      else
        idx = createI32Constant(rewriter, loc, i);
      Value elem = LLVM::ExtractElementOp::create(rewriter, loc, indices, idx);
      indicesVector.push_back(elem);
    }

    // Widen sub-32-bit indices to i32; i32 indices pass through unchanged.
    SmallVector<Value> indicesI32Vector;
    if (elementType == i32) {
      indicesI32Vector = indicesVector;
    } else {
      for (unsigned i = 0; i < targetSize; ++i) {
        Value index = indicesVector[i];
        indicesI32Vector.push_back(
            LLVM::ZExtOp::create(rewriter, loc, i32, index));
      }
      if ((targetSize % 2) != 0)
        // Add padding when not divisible by two.
        indicesI32Vector.push_back(consts[0]);
    }

    // For narrow indices, fuse consecutive pairs into single i32 lanes:
    // the even-indexed value fills bits [15:0], the odd one bits [31:16].
    SmallVector<Value> indicesToInsert;
    if (elementType == i32) {
      indicesToInsert = indicesI32Vector;
    } else {
      unsigned size = indicesI32Vector.size() / 2;
      for (unsigned i = 0; i < size; ++i) {
        Value first = indicesI32Vector[2 * i];
        Value second = indicesI32Vector[2 * i + 1];
        Value joined = setValueAtOffset(rewriter, loc, first, second, 16);
        indicesToInsert.push_back(joined);
      }
    }

    // zip_first: only as many lanes as there are packed values are written;
    // any remaining lanes of the v4i32 stay poison.
    Value dgroup = LLVM::PoisonOp::create(rewriter, loc, v4i32);
    for (auto [sgpr, constant] : llvm::zip_first(indicesToInsert, consts))
      dgroup =
          LLVM::InsertElementOp::create(rewriter, loc, dgroup, sgpr, constant);

    return dgroup;
  }
3902
3903 Value getDGroup2Gather(DescriptorOp op, OpAdaptor adaptor,
3904 ConversionPatternRewriter &rewriter, Location loc,
3905 ArrayRef<Value> consts) const {
3906 return getGatherIndices(op, adaptor, rewriter, loc, consts, true);
3907 }
3908
3909 std::pair<Value, Value>
3910 setTensorDim3Stride(DescriptorOp op, OpAdaptor adaptor,
3911 ConversionPatternRewriter &rewriter, Location loc,
3912 Value sgpr0, Value sgpr1, ArrayRef<Value> consts) const {
3913 constexpr int32_t dim = 3;
3914 constexpr int32_t offset = 0;
3915 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr0, sgpr1, consts,
3916 dim, offset);
3917 }
3918
3919 std::pair<Value, Value> setTensorDim4(DescriptorOp op, OpAdaptor adaptor,
3920 ConversionPatternRewriter &rewriter,
3921 Location loc, Value sgpr1, Value sgpr2,
3922 ArrayRef<Value> consts) const {
3923 constexpr int32_t dim = 4;
3924 constexpr int32_t offset = 48;
3925 return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, sgpr2, consts, dim,
3926 offset);
3927 }
3928
3929 Value setTileDim4(DescriptorOp op, OpAdaptor adaptor,
3930 ConversionPatternRewriter &rewriter, Location loc,
3931 Value sgpr2, ArrayRef<Value> consts) const {
3932 constexpr int32_t dim = 4;
3933 constexpr int32_t offset = 80;
3934 return setTileDimX(op, adaptor, rewriter, loc, sgpr2, consts, dim, offset);
3935 }
3936
3937 Value getDGroup3(DescriptorOp op, OpAdaptor adaptor,
3938 ConversionPatternRewriter &rewriter, Location loc,
3939 ArrayRef<Value> consts) const {
3940 if constexpr (DescriptorOp::isGather())
3941 return getDGroup3Gather(op, adaptor, rewriter, loc, consts);
3942 return getDGroup3NonGather(op, adaptor, rewriter, loc, consts);
3943 }
3944
3945 Value getDGroup3NonGather(DescriptorOp op, OpAdaptor adaptor,
3946 ConversionPatternRewriter &rewriter, Location loc,
3947 ArrayRef<Value> consts) const {
3948 IntegerType i32 = rewriter.getI32Type();
3949 Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
3950 assert(v4i32 && "expected type conversion to succeed.");
3951 bool onlyNeedsTwoDescriptors = !op.getLdsIncrement() && op.getRank() <= 2;
3952 if (onlyNeedsTwoDescriptors)
3953 return LLVM::ZeroOp::create(rewriter, loc, v4i32);
3954
3955 constexpr int32_t sgprlen = 4;
3956 Value sgprs[sgprlen];
3957 for (int i = 0; i < sgprlen; ++i)
3958 sgprs[i] = consts[0];
3959
3960 std::tie(sgprs[0], sgprs[1]) = setTensorDim3Stride(
3961 op, adaptor, rewriter, loc, sgprs[0], sgprs[1], consts);
3962 std::tie(sgprs[1], sgprs[2]) =
3963 setTensorDim4(op, adaptor, rewriter, loc, sgprs[1], sgprs[2], consts);
3964 sgprs[2] = setTileDim4(op, adaptor, rewriter, loc, sgprs[2], consts);
3965
3966 Value dgroup3 = LLVM::PoisonOp::create(rewriter, loc, v4i32);
3967 for (auto [sgpr, constant] : llvm::zip(sgprs, consts))
3968 dgroup3 =
3969 LLVM::InsertElementOp::create(rewriter, loc, dgroup3, sgpr, constant);
3970
3971 return dgroup3;
3972 }
3973
3974 Value getDGroup3Gather(DescriptorOp op, OpAdaptor adaptor,
3975 ConversionPatternRewriter &rewriter, Location loc,
3976 ArrayRef<Value> consts) const {
3977 return getGatherIndices(op, adaptor, rewriter, loc, consts, false);
3978 }
3979
3980 LogicalResult
3981 matchAndRewrite(DescriptorOp op, OpAdaptor adaptor,
3982 ConversionPatternRewriter &rewriter) const override {
3983 if (chipset < kGfx1250)
3984 return op->emitOpError(
3985 "make_dma_descriptor is only supported on gfx1250");
3986
3987 Location loc = op.getLoc();
3988
3989 SmallVector<Value> consts;
3990 for (int64_t i = 0; i < 8; ++i)
3991 consts.push_back(createI32Constant(rewriter, loc, i));
3992
3993 Value dgroup0 = this->getDGroup0(adaptor);
3994 Value dgroup1 = this->getDGroup1(op, adaptor, rewriter, loc, consts);
3995 Value dgroup2 = this->getDGroup2(op, adaptor, rewriter, loc, consts);
3996 Value dgroup3 = this->getDGroup3(op, adaptor, rewriter, loc, consts);
3997 SmallVector<Value> results = {dgroup0, dgroup1, dgroup2, dgroup3};
3998 rewriter.replaceOpWithMultiple(op, {results});
3999 return success();
4000 }
4001};
4002
/// Lowers AMDGPU tensor load/store ops to the corresponding ROCDL tensor DMA
/// intrinsics. `SourceOp` is the amdgpu-dialect op and `TargetOp` the ROCDL op
/// it maps to; both directions (load-to-LDS, store-from-LDS) share this
/// pattern. Only valid on gfx1250.
template <typename SourceOp, typename TargetOp>
struct AMDGPUTensorLoadStoreOpLowering
    : public ConvertOpToLLVMPattern<SourceOp> {
  using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
  AMDGPUTensorLoadStoreOpLowering(const LLVMTypeConverter &converter,
                                  Chipset chipset)
      : ConvertOpToLLVMPattern<SourceOp>(converter), chipset(chipset) {}
  // Target chipset; used only for the gfx1250 gate below.
  Chipset chipset;

  LogicalResult
  matchAndRewrite(SourceOp op, Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset < kGfx1250)
      return op->emitOpError("is only supported on gfx1250");

    // The TDM descriptor type converts 1 -> 4 values (v4i32, v8i32, v4i32,
    // v4i32), so the adaptor hands us the four descriptor groups here.
    ValueRange desc = adaptor.getDesc();
    // Create a <v8 x i32> 0 as the fifth argument to match llvm intrinsic. It
    // will move into the TDM descriptor once it becomes relevant for future use
    auto v8i32 = VectorType::get(8, rewriter.getI32Type());
    Value dgroup4 = LLVM::ZeroOp::create(rewriter, op.getLoc(), v8i32);
    rewriter.replaceOpWithNewOp<TargetOp>(op, desc[0], desc[1], desc[2],
                                          desc[3], dgroup4, /*cachePolicy=*/0,
                                          /*alias_scopes=*/nullptr,
                                          /*noalias_scopes=*/nullptr,
                                          /*tbaa=*/nullptr);
    return success();
  }
};
4032
4033struct GlobalPrefetchOpLowering
4034 : public ConvertOpToLLVMPattern<GlobalPrefetchOp> {
4035 GlobalPrefetchOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
4036 : ConvertOpToLLVMPattern<GlobalPrefetchOp>(converter), chipset(chipset) {}
4037
4038 LogicalResult
4039 matchAndRewrite(GlobalPrefetchOp op, GlobalPrefetchOpAdaptor adaptor,
4040 ConversionPatternRewriter &rewriter) const override {
4041 if (chipset < kGfx1250)
4042 return op->emitOpError("is only supported on gfx1250+");
4043
4044 const bool isSpeculative = op.getSpeculative();
4045 const int32_t immArgValue = getGlobalPrefetchLLVMEncoding(
4046 op.getTemporalHint(), op.getCacheScope(), isSpeculative);
4047 IntegerAttr immArgAttr = rewriter.getI32IntegerAttr(immArgValue);
4048
4049 ValueRange indices = adaptor.getIndices();
4050 Value memRef = adaptor.getSrc();
4051 MemRefDescriptor descriptor(memRef);
4052 MemRefType memRefType = op.getSrc().getType();
4053 Location loc = op->getLoc();
4054 auto inboundsFlags = isSpeculative ? LLVM::GEPNoWrapFlags::none
4055 : LLVM::GEPNoWrapFlags::inbounds |
4056 LLVM::GEPNoWrapFlags::nuw;
4057 Value prefetchPtr = getStridedElementPtr(
4058 rewriter, loc, memRefType, descriptor, indices, inboundsFlags);
4059
4060 rewriter.replaceOpWithNewOp<ROCDL::GlobalPrefetchOp>(
4061 op, prefetchPtr, immArgAttr, mlir::ArrayAttr{}, mlir::ArrayAttr{},
4062 mlir::ArrayAttr{});
4063 return success();
4064 }
4065
4066private:
4067 Chipset chipset;
4068};
4069
4070struct ConvertAMDGPUToROCDLPass
4071 : public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
4072 using Base::Base;
4073
4074 void runOnOperation() override {
4075 MLIRContext *ctx = &getContext();
4076 FailureOr<Chipset> maybeChipset = Chipset::parse(chipset);
4077 if (failed(maybeChipset)) {
4078 emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
4079 return signalPassFailure();
4080 }
4081
4082 RewritePatternSet patterns(ctx);
4083 LLVMTypeConverter converter(ctx);
4084
4085 populateAMDGPUToROCDLConversionPatterns(converter, patterns, *maybeChipset);
4086 amdgpu::populateCommonGPUTypeAndAttributeConversions(converter);
4087 LLVMConversionTarget target(getContext());
4088 target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
4089 target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
4090 target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
4091 if (failed(applyPartialConversion(getOperation(), target,
4092 std::move(patterns))))
4093 signalPassFailure();
4094 }
4095};
4096} // namespace
4097
4099 TypeConverter &typeConverter) {
4101 typeConverter, [](gpu::AddressSpace space) {
4102 switch (space) {
4103 case gpu::AddressSpace::Global:
4104 return ROCDL::ROCDLDialect::kGlobalMemoryAddressSpace;
4105 case gpu::AddressSpace::Workgroup:
4106 return ROCDL::ROCDLDialect::kSharedMemoryAddressSpace;
4107 case gpu::AddressSpace::Private:
4108 return ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace;
4109 case gpu::AddressSpace::Constant:
4110 return ROCDL::ROCDLDialect::kConstantMemoryAddressSpace;
4111 }
4112 llvm_unreachable("unknown address space enum value");
4113 });
4114}
4115
4117 TypeConverter &typeConverter) {
4118 typeConverter.addTypeAttributeConversion(
4119 [](BaseMemRefType type, amdgpu::AddressSpaceAttr as)
4120 -> TypeConverter::AttributeConversionResult {
4121 MLIRContext *ctx = as.getContext();
4122 Type i64 = IntegerType::get(ctx, 64);
4123 switch (as.getValue()) {
4124 case amdgpu::AddressSpace::FatRawBuffer:
4125 return IntegerAttr::get(i64, 7);
4126 case amdgpu::AddressSpace::BufferRsrc:
4127 return IntegerAttr::get(i64, 8);
4128 case amdgpu::AddressSpace::FatStructuredBuffer:
4129 return IntegerAttr::get(i64, 9);
4130 }
4131 return TypeConverter::AttributeConversionResult::abort();
4132 });
4133 typeConverter.addConversion([&](DsBarrierStateType type) -> Type {
4134 return IntegerType::get(type.getContext(), 64);
4135 });
4136 typeConverter.addConversion([&](TDMBaseType type) -> Type {
4137 Type i32 = IntegerType::get(type.getContext(), 32);
4138 return typeConverter.convertType(VectorType::get(4, i32));
4139 });
4140 typeConverter.addConversion([&](TDMGatherBaseType type) -> Type {
4141 Type i32 = IntegerType::get(type.getContext(), 32);
4142 return typeConverter.convertType(VectorType::get(4, i32));
4143 });
4144 typeConverter.addConversion(
4145 [&](TDMDescriptorType type,
4146 SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
4147 Type i32 = IntegerType::get(type.getContext(), 32);
4148 Type v4i32 = typeConverter.convertType(VectorType::get(4, i32));
4149 Type v8i32 = typeConverter.convertType(VectorType::get(8, i32));
4150 llvm::append_values(result, v4i32, v8i32, v4i32, v4i32);
4151 return success();
4152 });
4153
4154 auto addUnrealizedCast = [](OpBuilder &builder, TypeRange types,
4155 ValueRange inputs,
4157 // Only create unrealized_conversion_cast for TDMDescriptorType.
4158 // All other types which are not expected, should be
4159 // materialized by other target materialization functions.
4160 if (inputs.size() != 1)
4161 return {};
4162
4163 if (!isa<TDMDescriptorType>(inputs[0].getType()))
4164 return {};
4165
4166 auto cast = UnrealizedConversionCastOp::create(builder, loc, types, inputs);
4167 return cast.getResults();
4168 };
4169
4170 typeConverter.addTargetMaterialization(addUnrealizedCast);
4171}
4172
4174 RewritePatternSet &patterns,
4175 Chipset chipset) {
4177 patterns
4178 .add<FatRawBufferCastLowering,
4179 RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
4180 RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
4181 RawBufferOpLowering<RawBufferAtomicFaddOp,
4182 ROCDL::RawPtrBufferAtomicFaddOp>,
4183 RawBufferOpLowering<RawBufferAtomicFmaxOp,
4184 ROCDL::RawPtrBufferAtomicFmaxOp>,
4185 RawBufferOpLowering<RawBufferAtomicSmaxOp,
4186 ROCDL::RawPtrBufferAtomicSmaxOp>,
4187 RawBufferOpLowering<RawBufferAtomicUminOp,
4188 ROCDL::RawPtrBufferAtomicUminOp>,
4189 RawBufferOpLowering<RawBufferAtomicCmpswapOp,
4190 ROCDL::RawPtrBufferAtomicCmpSwap>,
4191 AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
4192 SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
4193 SparseMFMAOpLowering, WMMAOpLowering, ScaledWMMAOpLowering,
4194 SparseWMMAOpLowering, ExtPackedFp8OpLowering,
4195 ScaledExtPackedMatrixOpLowering, ScaledExtPackedOpLowering,
4196 PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
4197 PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
4198 GlobalLoadAsyncToLDSOpLowering, TransposeLoadOpLowering,
4199 AMDGPUPermlaneLowering, AMDGPUMakeDmaBaseLowering<MakeDmaBaseOp>,
4200 AMDGPUMakeDmaBaseLowering<MakeGatherDmaBaseOp>,
4201 AMDGPULowerDescriptor<MakeDmaDescriptorOp>,
4202 AMDGPULowerDescriptor<MakeGatherDmaDescriptorOp>,
4203 AMDGPUTensorLoadStoreOpLowering<TensorLoadToLDSOp,
4204 ROCDL::TensorLoadToLDSOp>,
4205 AMDGPUTensorLoadStoreOpLowering<TensorStoreFromLDSOp,
4206 ROCDL::TensorStoreFromLDSOp>,
4207 DsBarrierInitOpLowering, DsBarrierPollStateOpLowering,
4208 DsAsyncBarrierArriveOpLowering, DsBarrierArriveOpLowering,
4209 GlobalPrefetchOpLowering>(converter, chipset);
4210 patterns.add<AMDGPUSwizzleBitModeLowering, DsBarrierStatePhaseOpLowering,
4211 DsBarrierStatePendingCountOpLowering,
4212 DsBarrierStateInitCountOpLowering,
4213 DsBarrierStatePhaseParityLowering>(converter);
4214}
static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type)
Return true if type is the E4M3FN variant of an 8-bit float that is supported by the _fp8 instruction...
constexpr Chipset kGfx942
static std::optional< StringRef > wmmaOpToIntrinsicRDNA(Type elemSourceType, Type elemBSourceType, Type elemDestType, uint32_t k, bool isRDNA3)
Returns the rocdl intrinsic corresponding to a WMMA operation wmma for RDNA3/4 architectures.
static std::optional< SparseWMMAOpInfo > sparseWMMAOpToIntrinsic(SparseWMMAOp swmmac, Chipset chipset)
static std::optional< std::tuple< StringRef, uint32_t, uint32_t > > mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m, uint32_t n, uint32_t k, uint32_t b, Chipset chipset)
If there is a scaled MFMA instruction for the input element types aType and bType,...
static std::optional< StringRef > mfmaOpToIntrinsic(MFMAOp mfma, Chipset chipset)
Return the rocdl intrinsic corresponding to a MFMA operation mfma if one exists.
constexpr Chipset kGfx908
static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput, Value mlirInput, SmallVectorImpl< Value > &operands, SmallVectorImpl< NamedAttribute > &attrs, StringRef attrName)
Push an input operand.
constexpr Chipset kGfx1250
static Value castScaleOperand(ConversionPatternRewriter &rewriter, Location loc, Value input)
Converts the scaled MFMA/WMMA operands, scalesA and scalesB, from MLIR AMDGPU dialect convention to R...
constexpr Chipset kGfx90a
static std::optional< StringRef > getScaledWmmaIntrinsicName(int64_t m, int64_t n, int64_t k, bool isScale16)
Determines the ROCDL intrinsic name for scaled WMMA based on dimensions and scale block size (16 or 3...
static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, Value output, int32_t subwordOffset, bool clamp, SmallVectorImpl< Value > &operands, SmallVectorImpl< NamedAttribute > &attrs)
Push the output operand.
static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type)
Return true if type is the E5M2 variant of an 8-bit float that is supported by the _bf8 instructions ...
static std::optional< StringRef > wmmaOpToIntrinsic(WMMAOp wmma, Chipset chipset)
Returns the rocdl intrinsic corresponding to a WMMA operation wmma if one exists.
static std::optional< StringRef > smfmacOpToIntrinsic(SparseMFMAOp op, Chipset chipset)
Returns the rocdl intrinsic corresponding to a SparseMFMA (smfmac) operation if one exists.
static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc, Value basePointer, Value numRecords, bool boundsCheck, amdgpu::Chipset chipset, Value cacheSwizzleStride=nullptr, unsigned addressSpace=8)
static Value createI64Constant(ConversionPatternRewriter &rewriter, Location loc, int64_t value)
static std::optional< StringRef > wmmaOpToIntrinsicGfx1250(Type elemSourceType, Type elemBSourceType, Type elemDestType, uint32_t k)
Return the rocdl intrinsic corresponding to a WMMA operation wmma for the gfx1250 architecture.
static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc, MemRefType memrefType, MemRefDescriptor &memrefDescriptor, ArrayRef< int64_t > strides, int64_t elementByteWidth, amdgpu::Chipset chipset, bool boundsCheck)
Compute the contents of the num_records field for a given memref descriptor - that is,...
static Value packSmallFloatVectorOperand(ConversionPatternRewriter &rewriter, Location loc, Value input, bool allowBf16=true)
Pack small float vector operands (fp4/fp6/fp8/bf16) into the format expected by scaled matrix multipl...
static std::optional< uint32_t > getWmmaScaleFormat(Type elemType)
Maps f8 scale element types to WMMA scale format codes.
static Value getLinearIndexI32(ConversionPatternRewriter &rewriter, Location loc, MemRefDescriptor &memRefDescriptor, ValueRange indices, ArrayRef< int64_t > strides)
Returns the linear index used to access an element in the memref.
static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter, Location loc, Value val)
Convert an unsigned number val to i32.
static Value createI32Constant(ConversionPatternRewriter &rewriter, Location loc, int32_t value)
static std::optional< uint32_t > smallFloatTypeToFormatCode(Type mlirElemType)
static Value convertUnsignedToI64(ConversionPatternRewriter &rewriter, Location loc, Value val)
Convert an unsigned number val to i64.
constexpr Chipset kGfx950
static Value convertSparseVectorOperand(ConversionPatternRewriter &rewriter, Location loc, Value input, bool allowBf16=true)
Converts sparse MFMA/WMMA (smfmac/swmmac) operands to the expected ROCDL types.
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
b getContext())
auto load
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static constexpr unsigned kSizePosInMemRefDescriptor
static constexpr unsigned kStridePosInMemRefDescriptor
static constexpr unsigned kOffsetPosInMemRefDescriptor
static constexpr unsigned kAllocatedPtrPosInMemRefDescriptor
static constexpr unsigned kAlignedPtrPosInMemRefDescriptor
static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound)
This class provides a shared interface for ranked and unranked memref types.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition Pattern.h:227
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition Pattern.h:233
typename SourceOp::template GenericAdaptor< ArrayRef< ValueRange > > OneToNOpAdaptor
Definition Pattern.h:230
typename SourceOp::Adaptor OpAdaptor
Definition Pattern.h:229
Value getStridedElementPtr(ConversionPatternRewriter &rewriter, Location loc, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none) const
Convenience wrapper for the corresponding helper utility.
Definition Pattern.cpp:66
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Value stride(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
Value size(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
This class helps build Operations.
Definition Builders.h:209
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:433
result_range getResults()
Definition Operation.h:441
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:430
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isF64() const
Definition Types.cpp:41
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isF8E5M2() const
Definition Types.cpp:45
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition Types.cpp:78
bool isF8E4M3FN() const
Definition Types.cpp:44
bool isF32() const
Definition Types.cpp:40
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition Types.cpp:90
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
bool isF16() const
Definition Types.cpp:38
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
bool isBF16() const
Definition Types.cpp:37
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDes...
Definition Pattern.cpp:603
int32_t getGlobalPrefetchLLVMEncoding(amdgpu::LoadTemporalHint hint, amdgpu::Scope scope, bool isSpeculative)
Definition AMDGPUEnums.h:17
bool hasOcpFp8(const Chipset &chipset)
Definition Chipset.h:52
void populateCommonGPUTypeAndAttributeConversions(TypeConverter &typeConverter)
Remap common GPU memory spaces (Workgroup, Private, etc) to LLVM address spaces.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition Matchers.h:442
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void populateGpuMemorySpaceAttributeConversions(TypeConverter &typeConverter, const MemorySpaceMapping &mapping)
Populates memory space attribute conversion rules for lowering gpu.address_space to integer values.
void populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, amdgpu::Chipset chipset)
Note: This function will also add conversions for the AMDGPU-specific address spaces and types,...
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:139
void populateAMDGPUTypeAndAttributeConversions(TypeConverter &typeConverter)
Remap AMDGPU memory spaces to LLVM address spaces by mapping amdgpu::AddressSpace::fat_raw_buffer to ...
Returns the rocdl intrinsic corresponding to a SparseWMMA operation swmmac if one exists.
Represents the amdgpu gfx chipset version, e.g., gfx90a, gfx942, gfx1103.
Definition Chipset.h:22
unsigned majorVersion
Definition Chipset.h:23
unsigned minorVersion
Definition Chipset.h:24
static FailureOr< Chipset > parse(StringRef name)
Parses the chipset version string and returns the chipset on success, and failure otherwise.
Definition Chipset.cpp:14