//===- AMDGPUToROCDL.cpp - AMDGPU to ROCDL dialect conversion -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"

#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"

#include "../LLVMCommon/MemRefDescriptor.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_CONVERTAMDGPUTOROCDLPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::amdgpu;

// Define commonly used chipset versions for convenience.
constexpr Chipset kGfx908 = Chipset(9, 0, 8);
constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
constexpr Chipset kGfx942 = Chipset(9, 4, 2);
constexpr Chipset kGfx950 = Chipset(9, 5, 0);

/// Convert an unsigned number `val` to i32.
static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
                                  Location loc, Value val) {
  IntegerType i32 = rewriter.getI32Type();
  // Force check that `val` is of int type.
  auto valTy = cast<IntegerType>(val.getType());
  if (i32 == valTy)
    return val;
  return valTy.getWidth() > 32
             ? Value(LLVM::TruncOp::create(rewriter, loc, i32, val))
             : Value(LLVM::ZExtOp::create(rewriter, loc, i32, val));
}

static Value createI32Constant(ConversionPatternRewriter &rewriter,
                               Location loc, int32_t value) {
  return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), value);
}

/// Convert an unsigned number `val` to i64.
static Value convertUnsignedToI64(ConversionPatternRewriter &rewriter,
                                  Location loc, Value val) {
  IntegerType i64 = rewriter.getI64Type();
  // Force check that `val` is of int type.
  auto valTy = cast<IntegerType>(val.getType());
  if (i64 == valTy)
    return val;
  return valTy.getWidth() > 64
             ? Value(LLVM::TruncOp::create(rewriter, loc, i64, val))
             : Value(LLVM::ZExtOp::create(rewriter, loc, i64, val));
}

static Value createI64Constant(ConversionPatternRewriter &rewriter,
                               Location loc, int64_t value) {
  return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), value);
}

static Value createI1Constant(ConversionPatternRewriter &rewriter,
                              Location loc, bool value) {
  Type llvmI1 = rewriter.getI1Type();
  return LLVM::ConstantOp::create(rewriter, loc, llvmI1, value);
}

/// Returns the linear index used to access an element in the memref.
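/// For example, with indices (i, j) and strides [s0, 1], this produces
/// i * s0 + j as an i32; unit strides skip the multiplication entirely.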
static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
                               Location loc,
                               MemRefDescriptor &memRefDescriptor,
                               ValueRange indices, ArrayRef<int64_t> strides) {
  IntegerType i32 = rewriter.getI32Type();
  Value index;
  for (auto [i, increment, stride] : llvm::enumerate(indices, strides)) {
    if (stride != 1) { // Skip if stride is 1.
      Value strideValue =
          ShapedType::isDynamic(stride)
              ? convertUnsignedToI32(rewriter, loc,
                                     memRefDescriptor.stride(rewriter, loc, i))
              : LLVM::ConstantOp::create(rewriter, loc, i32, stride);
      increment = LLVM::MulOp::create(rewriter, loc, increment, strideValue);
    }
    index = index ? LLVM::AddOp::create(rewriter, loc, index, increment)
                  : increment;
  }
  return index ? index : createI32Constant(rewriter, loc, 0);
}

/// Compute the contents of the `num_records` field for a given memref
/// descriptor - that is, the number of bytes that's one element past the
/// greatest possible valid index into the memref.
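/// For example, a static memref<16x4xf32> with strides [4, 1] yields
/// max(16 * 4, 4 * 1) * 4 bytes = 256.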
static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc,
                           MemRefType memrefType,
                           MemRefDescriptor &memrefDescriptor,
                           ArrayRef<int64_t> strides,
                           int64_t elementByteWidth) {
  if (memrefType.hasStaticShape() &&
      !llvm::any_of(strides, ShapedType::isDynamic)) {
    int64_t size = memrefType.getRank() == 0 ? 1 : 0;
    ArrayRef<int64_t> shape = memrefType.getShape();
    for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
      size = std::max(shape[i] * strides[i], size);
    size = size * elementByteWidth;
    return createI64Constant(rewriter, loc, size);
  }
  Value maxIndex;
  for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
    Value size = memrefDescriptor.size(rewriter, loc, i);
    Value stride = memrefDescriptor.stride(rewriter, loc, i);
    Value maxThisDim = LLVM::MulOp::create(rewriter, loc, size, stride);
    maxIndex = maxIndex
                   ? LLVM::UMaxOp::create(rewriter, loc, maxIndex, maxThisDim)
                   : maxThisDim;
  }
  Value maxIndexI64 = convertUnsignedToI64(rewriter, loc, maxIndex);
  Value byteWidthConst = createI64Constant(rewriter, loc, elementByteWidth);
  return LLVM::MulOp::create(rewriter, loc, maxIndexI64, byteWidthConst);
}

static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc,
                            Value basePointer, Value numRecords,
                            bool boundsCheck, amdgpu::Chipset chipset,
                            Value cacheSwizzleStride = nullptr,
                            unsigned addressSpace = 8) {
  // The stride value is generally 0. However, on MI-300 and onward, you can
  // enable a cache swizzling mode by setting bit 14 of the stride field
  // and setting that stride to a cache stride.
  Type i16 = rewriter.getI16Type();
  Value stride;
  if (chipset.majorVersion == 9 && chipset >= kGfx942 && cacheSwizzleStride) {
    Value cacheStrideZext =
        LLVM::ZExtOp::create(rewriter, loc, i16, cacheSwizzleStride);
    Value swizzleBit = LLVM::ConstantOp::create(
        rewriter, loc, i16, rewriter.getI16IntegerAttr(1 << 14));
    stride = LLVM::OrOp::create(rewriter, loc, cacheStrideZext, swizzleBit,
                                /*isDisjoint=*/true);
  } else {
    stride = LLVM::ConstantOp::create(rewriter, loc, i16,
                                      rewriter.getI16IntegerAttr(0));
  }
  // Get the number of elements.
  // Flag word:
  // bits 0-11: dst sel, ignored by these intrinsics
  // bits 12-14: num format (ignored, must be nonzero, 7=float)
  // bits 15-18: data format (ignored, must be nonzero, 4=32bit)
  // bit 19: In nested heap (0 here)
  // bit 20: Behavior on unmap (0 means "return 0 / ignore")
  // bits 21-22: Index stride for swizzles (N/A)
  // bit 23: Add thread ID (0)
  // bit 24: Reserved to 1 (RDNA) or 0 (CDNA)
  // bits 25-26: Reserved (0)
  // bit 27: Buffer is non-volatile (CDNA only)
  // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 =
  // none, 3 = either swizzles or testing against offset field) RDNA only
  // bits 30-31: Type (must be 0)
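  // For example, on CDNA (gfx9) the flag word is (7 << 12) | (4 << 15) =
  // 0x27000; on RDNA (gfx10+) with bounds checking enabled it becomes
  // 0x27000 | (1 << 24) | (3 << 28) = 0x31027000.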
  uint32_t flags = (7 << 12) | (4 << 15);
  if (chipset.majorVersion >= 10) {
    flags |= (1 << 24);
    uint32_t oob = boundsCheck ? 3 : 2;
    flags |= (oob << 28);
  }
  Value flagsConst = createI32Constant(rewriter, loc, flags);
  Type rsrcType =
      LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
  Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
      loc, rsrcType, basePointer, stride, numRecords, flagsConst);
  return resource;
}

namespace {
struct FatRawBufferCastLowering
    : public ConvertOpToLLVMPattern<FatRawBufferCastOp> {
  FatRawBufferCastLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<FatRawBufferCastOp>(converter),
        chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value memRef = adaptor.getSource();
    Value unconvertedMemref = op.getSource();
    MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());
    MemRefDescriptor descriptor(memRef);

    DataLayout dataLayout = DataLayout::closest(op);
    int64_t elementByteWidth =
        dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;

    int64_t unusedOffset = 0;
    SmallVector<int64_t, 5> strideVals;
    if (failed(memrefType.getStridesAndOffset(strideVals, unusedOffset)))
      return op.emitOpError("Can't lower non-stride-offset memrefs");

    Value numRecords = adaptor.getValidBytes();
    if (!numRecords)
      numRecords = getNumRecords(rewriter, loc, memrefType, descriptor,
                                 strideVals, elementByteWidth);

    Value basePointer =
        adaptor.getResetOffset()
            ? descriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
                                   memrefType)
            : descriptor.alignedPtr(rewriter, loc);

    Value offset = adaptor.getResetOffset()
                       ? LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
                                                  rewriter.getIndexAttr(0))
                       : descriptor.offset(rewriter, loc);

    bool hasSizes = memrefType.getRank() > 0;
    // No need to unpack() and pack() all the individual sizes and strides,
    // so we'll just extract the arrays.
    Value sizes = hasSizes
                      ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
                                                     kSizePosInMemRefDescriptor)
                      : Value{};
    Value strides =
        hasSizes ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
                                                kStridePosInMemRefDescriptor)
                 : Value{};

    Value fatPtr = makeBufferRsrc(
        rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(),
        chipset, adaptor.getCacheSwizzleStride(), /*addressSpace=*/7);

    Value result = MemRefDescriptor::poison(
        rewriter, loc,
        getTypeConverter()->convertType(op.getResult().getType()));
    SmallVector<int64_t> pos{kAllocatedPtrPosInMemRefDescriptor};
    result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr, pos);
    result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr,
                                         kAlignedPtrPosInMemRefDescriptor);
    result = LLVM::InsertValueOp::create(rewriter, loc, result, offset,
                                         kOffsetPosInMemRefDescriptor);
    if (hasSizes) {
      result = LLVM::InsertValueOp::create(rewriter, loc, result, sizes,
                                           kSizePosInMemRefDescriptor);
      result = LLVM::InsertValueOp::create(rewriter, loc, result, strides,
                                           kStridePosInMemRefDescriptor);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// Define lowering patterns for raw buffer ops
template <typename GpuOp, typename Intrinsic>
struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
  RawBufferOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}

  Chipset chipset;
  static constexpr uint32_t maxVectorOpWidth = 128;

  LogicalResult
  matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = gpuOp.getLoc();
    Value memref = adaptor.getMemref();
    Value unconvertedMemref = gpuOp.getMemref();
    MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());

    if (chipset.majorVersion < 9)
      return gpuOp.emitOpError("raw buffer ops require GCN or higher");

    Value storeData = adaptor.getODSOperands(0)[0];
    if (storeData == memref) // no write component to this op
      storeData = Value();
    Type wantedDataType;
    if (storeData)
      wantedDataType = storeData.getType();
    else
      wantedDataType = gpuOp.getODSResults(0)[0].getType();

    Value atomicCmpData = Value();
    // Operand index 1 of a load is the indices; trying to read them can crash.
    if (storeData) {
      Value maybeCmpData = adaptor.getODSOperands(1)[0];
      if (maybeCmpData != memref)
        atomicCmpData = maybeCmpData;
    }

    Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);

    Type i32 = rewriter.getI32Type();

    // Get the type size in bytes.
    DataLayout dataLayout = DataLayout::closest(gpuOp);
    int64_t elementByteWidth =
        dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;
    Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);

    // If we want to load a vector<NxT> with total size <= 32
    // bits, use a scalar load and bitcast it. Similarly, if bitsize(T) < 32
    // and the total load size is >= 32, use a vector load of
    // (N * bitsize(T)) / 32 x i32 and bitcast. Also, the CAS intrinsic
    // requires integer operands, so bitcast any floats to integers.
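    // For example, a vector<2xf16> access (32 bits total) uses a single i32,
    // while a vector<4xf16> access (64 bits) uses a vector<2xi32>.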
    Type llvmBufferValType = llvmWantedDataType;
    if (atomicCmpData) {
      if (auto floatType = dyn_cast<FloatType>(wantedDataType))
        llvmBufferValType = this->getTypeConverter()->convertType(
            rewriter.getIntegerType(floatType.getWidth()));
    }
    if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
      uint32_t vecLen = dataVector.getNumElements();
      uint32_t elemBits =
          dataLayout.getTypeSizeInBits(dataVector.getElementType());
      uint32_t totalBits = elemBits * vecLen;
      bool usePackedFp16 =
          isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
      if (totalBits > maxVectorOpWidth)
        return gpuOp.emitOpError(
            "Total width of loads or stores must be no more than " +
            Twine(maxVectorOpWidth) + " bits, but we call for " +
            Twine(totalBits) +
            " bits. This should've been caught in validation");
      if (!usePackedFp16 && elemBits < 32) {
        if (totalBits > 32) {
          if (totalBits % 32 != 0)
            return gpuOp.emitOpError("Load or store of more than 32-bits that "
                                     "doesn't fit into words. Can't happen\n");
          llvmBufferValType = this->typeConverter->convertType(
              VectorType::get(totalBits / 32, i32));
        } else {
          llvmBufferValType = this->typeConverter->convertType(
              rewriter.getIntegerType(totalBits));
        }
      }
    }
    if (auto vecType = dyn_cast<VectorType>(llvmBufferValType)) {
      // Buffer intrinsics don't support 1-element vectors; cast them to
      // scalars.
      if (vecType.getNumElements() == 1)
        llvmBufferValType = vecType.getElementType();
    }

    SmallVector<Value, 6> args;
    if (storeData) {
      if (llvmBufferValType != llvmWantedDataType) {
        Value castForStore = LLVM::BitcastOp::create(
            rewriter, loc, llvmBufferValType, storeData);
        args.push_back(castForStore);
      } else {
        args.push_back(storeData);
      }
    }

    if (atomicCmpData) {
      if (llvmBufferValType != llvmWantedDataType) {
        Value castForCmp = LLVM::BitcastOp::create(
            rewriter, loc, llvmBufferValType, atomicCmpData);
        args.push_back(castForCmp);
      } else {
        args.push_back(atomicCmpData);
      }
    }

    // Construct buffer descriptor from memref, attributes
    int64_t offset = 0;
    SmallVector<int64_t, 5> strides;
    if (failed(memrefType.getStridesAndOffset(strides, offset)))
      return gpuOp.emitOpError("Can't lower non-stride-offset memrefs");

    MemRefDescriptor memrefDescriptor(memref);

    Value ptr = memrefDescriptor.bufferPtr(
        rewriter, loc, *this->getTypeConverter(), memrefType);
    Value numRecords = getNumRecords(
        rewriter, loc, memrefType, memrefDescriptor, strides, elementByteWidth);
    Value resource = makeBufferRsrc(rewriter, loc, ptr, numRecords,
                                    adaptor.getBoundsCheck(), chipset);
    args.push_back(resource);

    // Indexing (voffset)
    Value voffset = getLinearIndexI32(rewriter, loc, memrefDescriptor,
                                      adaptor.getIndices(), strides);
    if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
        indexOffset && *indexOffset > 0) {
      Value extraOffsetConst = createI32Constant(rewriter, loc, *indexOffset);
      voffset = voffset ? LLVM::AddOp::create(rewriter, loc, voffset,
                                              extraOffsetConst)
                        : extraOffsetConst;
    }
    voffset = LLVM::MulOp::create(rewriter, loc, voffset, byteWidthConst);
    args.push_back(voffset);

    // SGPR offset.
    Value sgprOffset = adaptor.getSgprOffset();
    if (!sgprOffset)
      sgprOffset = createI32Constant(rewriter, loc, 0);
    sgprOffset = LLVM::MulOp::create(rewriter, loc, sgprOffset, byteWidthConst);
    args.push_back(sgprOffset);

    // bit 0: GLC = 0 (atomics drop value, less coherency)
    // bits 1-2: SLC, DLC = 0 (similarly)
    // bit 3: swizzled (0 for raw)
    args.push_back(createI32Constant(rewriter, loc, 0));

    llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
                                           llvmBufferValType);
    Operation *lowered = Intrinsic::create(rewriter, loc, resultTypes, args,
                                           ArrayRef<NamedAttribute>());
    if (lowered->getNumResults() == 1) {
      Value replacement = lowered->getResult(0);
      if (llvmBufferValType != llvmWantedDataType) {
        replacement = LLVM::BitcastOp::create(rewriter, loc,
                                              llvmWantedDataType, replacement);
      }
      rewriter.replaceOp(gpuOp, replacement);
    } else {
      rewriter.eraseOp(gpuOp);
    }
    return success();
  }
};

// TODO: the AMDGPU backend already has all this bitpacking logic; we should
// move it to some common place.
/// Vmcnt, Expcnt and Lgkmcnt are decoded as follows:
/// Vmcnt = Waitcnt[3:0] (pre-gfx9)
/// Vmcnt = Waitcnt[15:14,3:0] (gfx9,10)
/// Vmcnt = Waitcnt[15:10] (gfx11)
/// Expcnt = Waitcnt[6:4] (pre-gfx11)
/// Expcnt = Waitcnt[2:0] (gfx11)
/// Lgkmcnt = Waitcnt[11:8] (pre-gfx10)
/// Lgkmcnt = Waitcnt[13:8] (gfx10)
/// Lgkmcnt = Waitcnt[9:4] (gfx11)
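/// For example, on gfx9, vmcnt=63, expcnt=7, lgkmcnt=15 encodes to
/// (63 & 0xF) | ((63 >> 4) << 14) | (7 << 4) | (15 << 8) = 0xCF7F.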
static FailureOr<unsigned> encodeWaitcnt(Chipset chipset, unsigned vmcnt,
                                         unsigned expcnt, unsigned lgkmcnt) {
  if (chipset.majorVersion < 9) {
    vmcnt = std::min(15u, vmcnt);
    expcnt = std::min(7u, expcnt);
    lgkmcnt = std::min(15u, lgkmcnt);
    return vmcnt | (expcnt << 4) | (lgkmcnt << 8);
  }
  if (chipset.majorVersion == 9) {
    vmcnt = std::min(63u, vmcnt);
    expcnt = std::min(7u, expcnt);
    lgkmcnt = std::min(15u, lgkmcnt);
    unsigned lowBits = vmcnt & 0xF;
    unsigned highBits = (vmcnt >> 4) << 14;
    unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
    return lowBits | highBits | otherCnts;
  }
  if (chipset.majorVersion == 10) {
    vmcnt = std::min(63u, vmcnt);
    expcnt = std::min(7u, expcnt);
    lgkmcnt = std::min(63u, lgkmcnt);
    unsigned lowBits = vmcnt & 0xF;
    unsigned highBits = (vmcnt >> 4) << 14;
    unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
    return lowBits | highBits | otherCnts;
  }
  if (chipset.majorVersion == 11) {
    vmcnt = std::min(63u, vmcnt);
    expcnt = std::min(7u, expcnt);
    lgkmcnt = std::min(63u, lgkmcnt);
    return (vmcnt << 10) | expcnt | (lgkmcnt << 4);
  }
  return failure();
}

struct MemoryCounterWaitOpLowering
    : public ConvertOpToLLVMPattern<MemoryCounterWaitOp> {
  MemoryCounterWaitOpLowering(const LLVMTypeConverter &converter,
                              Chipset chipset)
      : ConvertOpToLLVMPattern<MemoryCounterWaitOp>(converter),
        chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(MemoryCounterWaitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset.majorVersion >= 12) {
      Location loc = op.getLoc();
      if (std::optional<int> ds = adaptor.getDs())
        ROCDL::WaitDscntOp::create(rewriter, loc, *ds);

      if (std::optional<int> load = adaptor.getLoad())
        ROCDL::WaitLoadcntOp::create(rewriter, loc, *load);

      if (std::optional<int> store = adaptor.getStore())
        ROCDL::WaitStorecntOp::create(rewriter, loc, *store);

      if (std::optional<int> exp = adaptor.getExp())
        ROCDL::WaitExpcntOp::create(rewriter, loc, *exp);

      rewriter.eraseOp(op);
      return success();
    }

    auto getVal = [](Attribute attr) -> unsigned {
      if (attr)
        return cast<IntegerAttr>(attr).getInt();

      // This value will be clamped to the maximum value for the chipset.
      return 1024;
    };
    unsigned ds = getVal(adaptor.getDsAttr());
    unsigned exp = getVal(adaptor.getExpAttr());

    unsigned vmcnt = 1024;
    Attribute load = adaptor.getLoadAttr();
    Attribute store = adaptor.getStoreAttr();
    if (load && store) {
      vmcnt = getVal(load) + getVal(store);
    } else if (load) {
      vmcnt = getVal(load);
    } else if (store) {
      vmcnt = getVal(store);
    }

    FailureOr<unsigned> waitcnt = encodeWaitcnt(chipset, vmcnt, exp, ds);
    if (failed(waitcnt))
      return op.emitOpError("unsupported chipset");

    rewriter.replaceOpWithNewOp<ROCDL::SWaitcntOp>(op, *waitcnt);
    return success();
  }
};

struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
  LDSBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // This ensures that waits on global memory aren't introduced on
    // chips that don't have the BackOffBarrier feature enabled in LLVM.
    bool requiresInlineAsm = chipset < kGfx90a;

    Attribute mmra =
        rewriter.getAttr<LLVM::MMRATagAttr>("amdgpu-synchronize-as", "local");
    // Note: while there *is* a workgroup-one-as scope, this, when combined
    // with the MMRA, will lead to the fence having no effect. This is because
    // the codepaths for an atomic load or store will observe that a
    // one-address-space atomic to LDS requires no synchronization because
    // operations on LDS are totally ordered with respect to each other, and so
    // will not emit the correct waitcnt operations that these fences are
    // intended to produce. Therefore, we use a broader type of fence and rely
    // on the MMRA to relax it to the semantics we want.
    StringRef scope = "workgroup";

    auto relFence = LLVM::FenceOp::create(rewriter, loc,
                                          LLVM::AtomicOrdering::release, scope);
    relFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra);
    if (requiresInlineAsm) {
      auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
                                                      LLVM::AsmDialect::AD_ATT);
      const char *asmStr = ";;;WARNING: BREAKS DEBUG WATCHES\ns_barrier";
      const char *constraints = "";
      LLVM::InlineAsmOp::create(
          rewriter, loc,
          /*resultTypes=*/TypeRange(), /*operands=*/ValueRange(),
          /*asm_string=*/asmStr, constraints, /*has_side_effects=*/true,
          /*is_align_stack=*/false, LLVM::TailCallKind::None,
          /*asm_dialect=*/asmDialectAttr,
          /*operand_attrs=*/ArrayAttr());
    } else if (chipset.majorVersion < 12) {
      ROCDL::SBarrierOp::create(rewriter, loc);
    } else {
      ROCDL::BarrierSignalOp::create(rewriter, loc, -1);
      ROCDL::BarrierWaitOp::create(rewriter, loc, -1);
    }

    auto acqFence = LLVM::FenceOp::create(rewriter, loc,
                                          LLVM::AtomicOrdering::acquire, scope);
    acqFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra);
    rewriter.replaceOp(op, acqFence);
    return success();
  }
};

struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
  SchedBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<SchedBarrierOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<ROCDL::SchedBarrier>(op,
                                                     (uint32_t)op.getOpts());
    return success();
  }
};

} // namespace

/// Converts a MFMA vector operand from MLIR AMDGPU dialect convention to ROCDL
/// and LLVM AMDGPU intrinsics convention.
///
/// Specifically:
/// 1. If the element type is bfloat16, bitcast it to i16 unless the rocdl
/// intrinsic allows bf16. Newer MFMAs support bf16 types as operands; check
/// the IntrinsicsAMDGPU.td file for reference.
/// 2. If instead we have a more than 64-bit quantity, use a <N / 4 x i32>
/// instead, which is what the f8f6f4 intrinsics use.
/// 3. If `input` is a vector of N <= 8 bytes, bitcast it to a (N * 8)-bit
/// integer.
///
/// Note that the type of `input` has already been LLVM type converted:
/// therefore 8-bit and smaller floats are represented as their corresponding
/// `iN` integers.
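/// For example, a vector<4xi8> (e.g. four 8-bit floats after conversion)
/// becomes an i32, and a vector<32xi8> becomes a vector<8xi32>.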
static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
                                      Location loc, Value input,
                                      bool allowBf16 = true) {
  Type inputType = input.getType();
  if (auto vectorType = dyn_cast<VectorType>(inputType)) {
    if (vectorType.getElementType().isBF16() && !allowBf16)
      return LLVM::BitcastOp::create(
          rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
    if (vectorType.getElementType().isInteger(8) &&
        vectorType.getNumElements() <= 8)
      return LLVM::BitcastOp::create(
          rewriter, loc,
          rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
    if (isa<IntegerType>(vectorType.getElementType()) &&
        vectorType.getElementTypeBitWidth() <= 8) {
      int64_t numWords = llvm::divideCeil(
          vectorType.getNumElements() * vectorType.getElementTypeBitWidth(),
          32);
      return LLVM::BitcastOp::create(
          rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()),
          input);
    }
  }
  return input;
}

/// Converts the scaled MFMA operands, `scalesA` and `scalesB`, from MLIR
/// AMDGPU dialect convention to ROCDL and LLVM AMDGPU intrinsics convention.
///
/// Specifically:
/// 1. If `input` is an i8 value, zero extend it to i32.
/// 2. If `input` is a vector of length 4 and type i8, cast it to i32.
///
/// Note that the type of `input` has already been LLVM type converted:
/// therefore 8-bit and smaller floats are represented as their corresponding
/// `iN` integers.
static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter,
                                  Location loc, Value input) {
  Type inputType = input.getType();
  Type outputType = rewriter.getI32Type();
  if (auto intType = dyn_cast<IntegerType>(inputType))
    return LLVM::ZExtOp::create(rewriter, loc, outputType, input);
  return LLVM::BitcastOp::create(rewriter, loc, outputType, input);
}
677 
678 /// Push an input operand. If it is a float type, nothing to do. If it is
679 /// an integer type, then we need to also push its signdness (1 for signed, 0
680 /// for unsigned) and we need to pack the input 16xi8 vector into a 4xi32
681 /// vector (or the 8xi8 vector into a 2xi32 one for gfx12+).
682 /// We also need to convert bfloat inputs to i16 to account for the bfloat
683 /// intrinsics having been defined before the AMD backend supported bfloat. We
684 /// similarly need to pack 8-bit float types into integers as if they were i8
685 /// (which they are for the backend's purposes).
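/// For example, a signed vector<16xi8> input pushes an i1 "signed" constant
/// followed by the input bitcast to vector<4xi32>.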
static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
                                 Location loc,
                                 const TypeConverter *typeConverter,
                                 bool isUnsigned, Value llvmInput,
                                 Value mlirInput,
                                 SmallVector<Value, 4> &operands) {
  Type inputType = llvmInput.getType();
  auto vectorType = dyn_cast<VectorType>(inputType);
  if (!vectorType) {
    operands.push_back(llvmInput);
    return;
  }
  Type elemType = vectorType.getElementType();

  if (elemType.isBF16())
    llvmInput = LLVM::BitcastOp::create(
        rewriter, loc, vectorType.clone(rewriter.getI16Type()), llvmInput);
  if (elemType.getIntOrFloatBitWidth() > 8) {
    operands.push_back(llvmInput);
    return;
  }

  // We need to check the type of the input before conversion to properly test
  // for int8. This is because, in LLVM, the fp8 type is converted to int8, so
  // the fp8/int8 information is lost during the conversion process.
  auto mlirInputType = cast<VectorType>(mlirInput.getType());
  bool isInputInteger = mlirInputType.getElementType().isInteger();
  if (isInputInteger) {
    // If the element type is 8-bit signed or unsigned, ignore the isUnsigned
    // flag.
    bool localIsUnsigned = isUnsigned;
    if (elemType.isUnsignedInteger()) {
      localIsUnsigned = true;
    } else if (elemType.isSignedInteger()) {
      localIsUnsigned = false;
    }
    Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
    operands.push_back(sign);
  }

  int64_t numBits =
      vectorType.getNumElements() * elemType.getIntOrFloatBitWidth();
  Type i32 = rewriter.getI32Type();
  Type intrinsicInType = numBits <= 32
                             ? (Type)rewriter.getIntegerType(numBits)
                             : (Type)VectorType::get(numBits / 32, i32);
  auto llvmIntrinsicInType = typeConverter->convertType(intrinsicInType);
  Value castInput = rewriter.createOrFold<LLVM::BitcastOp>(
      loc, llvmIntrinsicInType, llvmInput);
  // The wave64-mode 16x16x16 intrinsics that take 4-bit integers only need
  // (256 / 64) * 4 = 16 bits of input (on gfx12+) but take i32 arguments.
  // Add in the zeros here.
  if (numBits < 32)
    castInput = LLVM::ZExtOp::create(rewriter, loc, i32, castInput);
  operands.push_back(castInput);
}

/// Push the output operand. For many cases this is only pushing the output in
/// the operand list. But when we have f16 -> f16 or bf16 -> bf16 intrinsics,
/// since the same number of VGPRs is used, we need to decide whether to store
/// the result in the upper or the lower 16 bits of the VGPRs. To store the
/// result in the lower 16 bits, set subwordOffset to 1; otherwise the result
/// will be stored in the upper part. The subwordOffset must not be set for
/// gfx12, as the instructions have been changed to return fewer registers
/// instead.
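/// For example, an f16 accumulator pushes the (possibly bitcast) output
/// followed by an i1 holding subwordOffset, while an i32 accumulator pushes
/// the clamp flag instead.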
static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
                                  Location loc,
                                  const TypeConverter *typeConverter,
                                  Value output, int32_t subwordOffset,
                                  bool clamp, SmallVector<Value, 4> &operands) {
  Type inputType = output.getType();
  auto vectorType = dyn_cast<VectorType>(inputType);
  Type elemType = vectorType.getElementType();
  if (elemType.isBF16())
    output = LLVM::BitcastOp::create(
        rewriter, loc, vectorType.clone(rewriter.getI16Type()), output);
  operands.push_back(output);
  if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) {
    operands.push_back(createI1Constant(rewriter, loc, subwordOffset));
  } else if (elemType.isInteger(32)) {
    operands.push_back(createI1Constant(rewriter, loc, clamp));
  }
}

/// Return true if `type` is the E5M2 variant of an 8-bit float that is
/// supported by the `_bf8` instructions on the given `chipset`.
static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type) {
  return (chipset == kGfx942 && isa<Float8E5M2FNUZType>(type)) ||
         (hasOcpFp8(chipset) && isa<Float8E5M2Type>(type));
}

/// Return true if `type` is the E4M3FN variant of an 8-bit float that is
/// supported by the `_fp8` instructions on the given `chipset`.
static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type) {
  return (chipset == kGfx942 && isa<Float8E4M3FNUZType>(type)) ||
         (hasOcpFp8(chipset) && isa<Float8E4M3FNType>(type));
}

/// Return the `rocdl` intrinsic corresponding to a MFMA operation `mfma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
                                                  Chipset chipset) {
  uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
           b = mfma.getBlocks();
  Type sourceElem = getElementTypeOrSelf(mfma.getSourceA().getType());
  Type destElem = getElementTypeOrSelf(mfma.getDestC().getType());

  if (sourceElem.isF32() && destElem.isF32()) {
    if (mfma.getReducePrecision() && chipset >= kGfx942) {
      if (m == 32 && n == 32 && k == 4 && b == 1)
        return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
      if (m == 16 && n == 16 && k == 8 && b == 1)
        return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
    }
    if (m == 32 && n == 32 && k == 1 && b == 2)
      return ROCDL::mfma_f32_32x32x1f32::getOperationName();
    if (m == 16 && n == 16 && k == 1 && b == 4)
      return ROCDL::mfma_f32_16x16x1f32::getOperationName();
    if (m == 4 && n == 4 && k == 1 && b == 16)
      return ROCDL::mfma_f32_4x4x1f32::getOperationName();
    if (m == 32 && n == 32 && k == 2 && b == 1)
      return ROCDL::mfma_f32_32x32x2f32::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f32_16x16x4f32::getOperationName();
  }

  if (sourceElem.isF16() && destElem.isF32()) {
    if (chipset >= kGfx950) {
      if (m == 32 && n == 32 && k == 16 && b == 1)
        return ROCDL::mfma_f32_32x32x16_f16::getOperationName();
      if (m == 16 && n == 16 && k == 32 && b == 1)
        return ROCDL::mfma_f32_16x16x32_f16::getOperationName();
    }
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_f32_32x32x4f16::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_f32_16x16x4f16::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_f32_4x4x4f16::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_f32_32x32x8f16::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_f32_16x16x16f16::getOperationName();
  }

  if (sourceElem.isBF16() && destElem.isF32()) {
    if (chipset >= kGfx950) {
      if (m == 32 && n == 32 && k == 16 && b == 1)
        return ROCDL::mfma_f32_32x32x16_bf16::getOperationName();
      if (m == 16 && n == 16 && k == 32 && b == 1)
        return ROCDL::mfma_f32_16x16x32_bf16::getOperationName();
    }
    if (chipset >= kGfx90a) {
      if (m == 32 && n == 32 && k == 4 && b == 2)
        return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
      if (m == 16 && n == 16 && k == 4 && b == 4)
        return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
      if (m == 4 && n == 4 && k == 4 && b == 16)
        return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
      if (m == 32 && n == 32 && k == 8 && b == 1)
        return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
      if (m == 16 && n == 16 && k == 16 && b == 1)
        return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
    }
    if (m == 32 && n == 32 && k == 2 && b == 2)
      return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
    if (m == 16 && n == 16 && k == 2 && b == 4)
      return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
    if (m == 4 && n == 4 && k == 2 && b == 16)
      return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
    if (m == 32 && n == 32 && k == 4 && b == 1)
      return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
    if (m == 16 && n == 16 && k == 8 && b == 1)
      return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
  }

  if (sourceElem.isInteger(8) && destElem.isInteger(32)) {
    if (chipset >= kGfx950) {
      if (m == 32 && n == 32 && k == 32 && b == 1)
        return ROCDL::mfma_i32_32x32x32_i8::getOperationName();
      if (m == 16 && n == 16 && k == 64 && b == 1)
        return ROCDL::mfma_i32_16x16x64_i8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_i32_32x32x4i8::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_i32_16x16x4i8::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_i32_4x4x4i8::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_i32_32x32x8i8::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_i32_16x16x16i8::getOperationName();
    if (m == 32 && n == 32 && k == 16 && b == 1 && chipset >= kGfx942)
      return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
    if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >= kGfx942)
      return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
  }

  if (sourceElem.isF64() && destElem.isF64() && chipset >= kGfx90a) {
    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f64_16x16x4f64::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 4)
      return ROCDL::mfma_f64_4x4x4f64::getOperationName();
  }

  if (destElem.isF32() && typeIsExpectedBf8ForChipset(chipset, sourceElem)) {
    // Known to be correct because there are no scalar f8 instructions and
    // because a length mismatch will have been caught by the verifier.
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
    }
  }

  if (destElem.isF32() && typeIsExpectedFp8ForChipset(chipset, sourceElem)) {
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
    }
  }

  return std::nullopt;
}

static std::optional<uint32_t> mfmaTypeSelectCode(Type mlirElemType) {
  return llvm::TypeSwitch<Type, std::optional<uint32_t>>(mlirElemType)
      .Case([](Float8E4M3FNType) { return 0u; })
      .Case([](Float8E5M2Type) { return 1u; })
      .Case([](Float6E2M3FNType) { return 2u; })
      .Case([](Float6E3M2FNType) { return 3u; })
      .Case([](Float4E2M1FNType) { return 4u; })
      .Default([](Type) { return std::nullopt; });
}

/// If there is a scaled MFMA instruction for the input element types `aType`
/// and `bType`, output type `destType`, problem size M, N, K, and B (number of
/// blocks) on the given `chipset`, return a tuple consisting of the
/// OperationName of the intrinsic and the type codes that need to be passed to
/// that intrinsic. Note that this is also used to implement some un-scaled
/// MFMAs, since the compiler represents the ordinary instruction as a "scaled"
/// MFMA with a scale of 0.
static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m,
                        uint32_t n, uint32_t k, uint32_t b, Chipset chipset) {
  aType = getElementTypeOrSelf(aType);
  bType = getElementTypeOrSelf(bType);
  destType = getElementTypeOrSelf(destType);

  if (chipset < kGfx950)
    return std::nullopt;
  if (!isa<Float32Type>(destType))
    return std::nullopt;

  std::optional<uint32_t> aTypeCode = mfmaTypeSelectCode(aType);
  std::optional<uint32_t> bTypeCode = mfmaTypeSelectCode(bType);
  if (!aTypeCode || !bTypeCode)
    return std::nullopt;

  if (m == 32 && n == 32 && k == 64 && b == 1)
    return std::tuple{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(),
                      *aTypeCode, *bTypeCode};
  if (m == 16 && n == 16 && k == 128 && b == 1)
    return std::tuple{
        ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), *aTypeCode,
        *bTypeCode};

  return std::nullopt;
}

static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
mfmaOpToScaledIntrinsic(MFMAOp mfma, Chipset chipset) {
  return mfmaOpToScaledIntrinsic(
      mfma.getSourceA().getType(), mfma.getSourceB().getType(),
      mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(),
      mfma.getBlocks(), chipset);
}

static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
  return mfmaOpToScaledIntrinsic(smfma.getSourceA().getType(),
                                 smfma.getSourceB().getType(),
                                 smfma.getDestC().getType(), smfma.getM(),
                                 smfma.getN(), smfma.getK(), 1u, chipset);
}

/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
                                                  Chipset chipset) {
  auto sourceVectorType = dyn_cast<VectorType>(wmma.getSourceA().getType());
  auto sourceBVectorType = dyn_cast<VectorType>(wmma.getSourceB().getType());
  auto destVectorType = dyn_cast<VectorType>(wmma.getDestC().getType());
  auto elemSourceType = sourceVectorType.getElementType();
  auto elemBSourceType = sourceBVectorType.getElementType();
  auto elemDestType = destVectorType.getElementType();

  if (elemSourceType.isF16() && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
  if (elemSourceType.isBF16() && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
  if (elemSourceType.isF16() && elemDestType.isF16())
    return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
  if (elemSourceType.isBF16() && elemDestType.isBF16())
    return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
  if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
    return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
  if (chipset.majorVersion == 11) {
    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
      return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
  }
  if (chipset.majorVersion >= 12) {
    if (isa<Float8E4M3FNType>(elemSourceType) &&
        isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
    if (isa<Float8E4M3FNType>(elemSourceType) &&
        isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
    if (isa<Float8E5M2Type>(elemSourceType) &&
        isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
    if (isa<Float8E5M2Type>(elemSourceType) &&
        isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) {
      bool isWave64 = destVectorType.getNumElements() == 4;
      // This is the ambiguous case. 8 inputs to the wave64 version means that
      // we want the 16x16x32 version, but for wave32 they mean the short form.
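      // For example, wave64 with 8 inputs selects 16x16x32, and so does
      // wave32 with 16 inputs; the other combinations select 16x16x16.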
      bool has8Inputs = sourceVectorType.getNumElements() == 8;
      if ((isWave64 && has8Inputs) || (!isWave64 && !has8Inputs))
        return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
      return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
    }
  }
  return std::nullopt;
}

namespace {
struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
  MFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type outType = typeConverter->convertType(op.getDestD().getType());
    Type intrinsicOutType = outType;
    if (auto outVecType = dyn_cast<VectorType>(outType))
      if (outVecType.getElementType().isBF16())
        intrinsicOutType = outVecType.clone(rewriter.getI16Type());

    if (chipset.majorVersion != 9 || chipset < kGfx908)
      return op->emitOpError("MFMA only supported on gfx908+");
    uint32_t getBlgpField = static_cast<uint32_t>(op.getBlgp());
    if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
      if (chipset < kGfx942)
        return op.emitOpError("negation unsupported on older than gfx942");
      getBlgpField |=
          op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
    }
    std::optional<StringRef> maybeIntrinsic = mfmaOpToIntrinsic(op, chipset);
    std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
        maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
    if (!maybeIntrinsic.has_value() && !maybeScaledIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching MFMA size on given chipset");

    bool isScaled =
        !maybeIntrinsic.has_value() && maybeScaledIntrinsic.has_value();
    if (isScaled &&
        (adaptor.getAbid() > 0 || getBlgpField > 0 || op.getCbsz() > 0)) {
      return op.emitOpError(
          "non-default abid, blgp, and cbsz aren't supported on MFMAs that can "
          "be scaled as those fields are used for type information");
    }

    StringRef intrinsicName =
        isScaled ? std::get<0>(*maybeScaledIntrinsic) : *maybeIntrinsic;
    // Determine if we can use bf16 in the intrinsic. Newer MFMAs on gfx950+
    // allow bf16 as the input; for reference, check the IntrinsicsAMDGPU.td
    // file.
    bool allowBf16 = [&]() {
      if (chipset < kGfx950)
        return false;
      if (isScaled)
        return true;
      return intrinsicName.contains("16x16x32.bf16") ||
             intrinsicName.contains("32x32x16.bf16");
    }();
    OperationState loweredOp(loc, intrinsicName);
    loweredOp.addTypes(intrinsicOutType);
    loweredOp.addOperands({convertMFMAVectorOperand(
                               rewriter, loc, adaptor.getSourceA(), allowBf16),
                           convertMFMAVectorOperand(
                               rewriter, loc, adaptor.getSourceB(), allowBf16),
                           adaptor.getDestC()});
    if (isScaled) {
      Value zero = createI32Constant(rewriter, loc, 0);
      auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
      loweredOp.addOperands({createI32Constant(rewriter, loc, aTypeCode),
                             createI32Constant(rewriter, loc, bTypeCode),
                             /*scale A byte=*/zero, /*scale A=*/zero,
                             /*scale B byte=*/zero, /*scale B=*/zero});
    } else {
      loweredOp.addOperands({createI32Constant(rewriter, loc, op.getCbsz()),
                             createI32Constant(rewriter, loc, op.getAbid()),
                             createI32Constant(rewriter, loc, getBlgpField)});
    }
    Value lowered = rewriter.create(loweredOp)->getResult(0);
    if (outType != intrinsicOutType)
      lowered = LLVM::BitcastOp::create(rewriter, loc, outType, lowered);
    rewriter.replaceOp(op, lowered);
    return success();
  }
};

struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
  ScaledMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type intrinsicOutType = typeConverter->convertType(op.getDestD().getType());

    if (chipset.majorVersion != 9 || chipset < kGfx950)
      return op->emitOpError("scaled MFMA only supported on gfx950+");
    std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
        maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
    if (!maybeScaledIntrinsic.has_value())
      return op.emitOpError(
          "no intrinsic matching scaled MFMA size on given chipset");

    auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
    OperationState loweredOp(loc, intrinsicName);
    loweredOp.addTypes(intrinsicOutType);
    loweredOp.addOperands(
        {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
         convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
         adaptor.getDestC()});
    Value scalesIdxA =
        createI32Constant(rewriter, loc, adaptor.getScalesIdxA());
    Value scalesIdxB =
        createI32Constant(rewriter, loc, adaptor.getScalesIdxB());
    loweredOp.addOperands(
        {createI32Constant(rewriter, loc, aTypeCode),
         createI32Constant(rewriter, loc, bTypeCode),
         /*scales idx A=*/scalesIdxA,
         /*scales A=*/
         castMFMAScaleOperand(rewriter, loc, adaptor.getScalesA()),
         /*scales idx B=*/scalesIdxB,
         /*scales B=*/
         castMFMAScaleOperand(rewriter, loc, adaptor.getScalesB())});
    Value lowered = rewriter.create(loweredOp)->getResult(0);
    rewriter.replaceOp(op, lowered);
    return success();
  }
};

struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
  WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto outType =
        typeConverter->convertType<VectorType>(op.getDestD().getType());
    if (!outType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
      return op->emitOpError("WMMA only supported on gfx11 and gfx12");

    // The WMMA operations represent vectors of bf16s as vectors of i16s, so we
    // need to bitcast bfloats to i16 and then bitcast them back.
    VectorType rawOutType = outType;
    if (outType.getElementType().isBF16())
      rawOutType = outType.clone(rewriter.getI16Type());

    std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);

    if (!maybeIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching WMMA on the given chipset");

    if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0)
      return op.emitOpError("subwordOffset not supported on gfx12+");

    OperationState loweredOp(loc, *maybeIntrinsic);
    loweredOp.addTypes(rawOutType);

    SmallVector<Value, 4> operands;
    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(),
                         adaptor.getSourceA(), op.getSourceA(), operands);
    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(),
                         adaptor.getSourceB(), op.getSourceB(), operands);
    wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(),
                          op.getSubwordOffset(), op.getClamp(), operands);

    loweredOp.addOperands(operands);
    Operation *lowered = rewriter.create(loweredOp);

    Operation *maybeCastBack = lowered;
    if (rawOutType != outType)
      maybeCastBack = LLVM::BitcastOp::create(rewriter, loc, outType,
                                              lowered->getResult(0));
    rewriter.replaceOp(op, maybeCastBack->getResults());

    return success();
  }
};

struct TransposeLoadOpLowering
    : public ConvertOpToLLVMPattern<TransposeLoadOp> {
  TransposeLoadOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<TransposeLoadOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(TransposeLoadOp op, TransposeLoadOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset != kGfx950)
      return op.emitOpError("Non-gfx950 chipset not supported");

    Location loc = op.getLoc();
    auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());

    // Elements in sub-byte memrefs are stored non-contiguously, so reject
    // sub-byte source memrefs; use emulated memrefs instead.
    size_t srcElementSize =
        srcMemRefType.getElementType().getIntOrFloatBitWidth();
    if (srcElementSize < 8)
      return op.emitOpError("Expect source memref to have at least 8 bits "
                            "element size, got ")
             << srcElementSize;

    auto resultType = cast<VectorType>(op.getResult().getType());
    Value srcPtr =
        getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
                             (adaptor.getSrcIndices()));

    size_t numElements = resultType.getNumElements();
    size_t elementTypeSize =
        resultType.getElementType().getIntOrFloatBitWidth();

    // ROCDL transpose load intrinsics return vectors of 32-bit integers if
    // the element size is smaller than 16 bits.
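    // For example, a vector<16xi4> result (64 bits total) is returned by
    // ds_read_tr4_b64 as a vector<2xi32> and then bitcast to the MLIR type.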
    Type rocdlResultType = VectorType::get((numElements * elementTypeSize) / 32,
                                           rewriter.getIntegerType(32));
    Type llvmResultType = typeConverter->convertType(resultType);

    switch (elementTypeSize) {
    case 4: {
      assert(numElements == 16);
      auto rocdlOp = ROCDL::ds_read_tr4_b64::create(rewriter, loc,
                                                    rocdlResultType, srcPtr);
      rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
      break;
    }
    case 6: {
      assert(numElements == 16);
      auto rocdlOp = ROCDL::ds_read_tr6_b96::create(rewriter, loc,
                                                    rocdlResultType, srcPtr);
      rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
      break;
    }
    case 8: {
      assert(numElements == 8);
      auto rocdlOp = ROCDL::ds_read_tr8_b64::create(rewriter, loc,
                                                    rocdlResultType, srcPtr);
      rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
      break;
    }
    case 16: {
      assert(numElements == 4);
      rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr16_b64>(op, llvmResultType,
                                                           srcPtr);
      break;
    }
    default:
      return op.emitOpError("Unsupported element size for transpose load");
    }
    return success();
  }
};

struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
  GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset.majorVersion < 9 || chipset.majorVersion > 10)
      return op.emitOpError("pre-gfx9 and post-gfx10 not supported");

    Location loc = op.getLoc();

    auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
    auto dstMemRefType = cast<MemRefType>(op.getDst().getType());

    // TODO: instead of only transferring one element per thread, we could
    // augment it to transfer multiple elements per thread by issuing multiple
    // `global_load_lds` instructions.
    Type transferType = op.getTransferType();
    int loadWidth = [&]() -> int {
      if (auto transferVectorType = dyn_cast<VectorType>(transferType)) {
        return (transferVectorType.getNumElements() *
                transferVectorType.getElementTypeBitWidth()) /
               8;
      }
      return transferType.getIntOrFloatBitWidth() / 8;
    }();

    // Currently only 1, 2, 4, 12 and 16 byte loads are supported.
    if (!llvm::is_contained({1, 2, 4, 12, 16}, loadWidth))
      return op.emitOpError("chipset unsupported element size");

    if (chipset != kGfx950 && llvm::is_contained({12, 16}, loadWidth))
      return op.emitOpError("Gather to LDS instructions with 12-byte and "
                            "16-byte load widths are only supported on gfx950");

    Value srcPtr =
        getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
                             (adaptor.getSrcIndices()));
    Value dstPtr =
        getStridedElementPtr(rewriter, loc, dstMemRefType, adaptor.getDst(),
                             (adaptor.getDstIndices()));

    rewriter.replaceOpWithNewOp<ROCDL::LoadToLDSOp>(
        op, srcPtr, dstPtr, rewriter.getI32IntegerAttr(loadWidth),
        /*offset=*/rewriter.getI32IntegerAttr(0),
        /*aux=*/rewriter.getI32IntegerAttr(0), ArrayAttr{}, ArrayAttr{},
        ArrayAttr{});

    return success();
  }
};

namespace {
struct ExtPackedFp8OpLowering final
    : public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
  ExtPackedFp8OpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::ExtPackedFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

struct PackedTrunc2xFp8OpLowering final
    : public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> {
  PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter,
                             Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::PackedTrunc2xFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

struct PackedStochRoundFp8OpLowering final
    : public ConvertOpToLLVMPattern<PackedStochRoundFp8Op> {
  PackedStochRoundFp8OpLowering(const LLVMTypeConverter &converter,
                                Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::PackedStochRoundFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PackedStochRoundFp8Op op,
                  PackedStochRoundFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

struct ScaledExtPackedOpLowering final
    : public ConvertOpToLLVMPattern<ScaledExtPackedOp> {
  ScaledExtPackedOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedOp>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

struct PackedScaledTruncOpLowering final
    : public ConvertOpToLLVMPattern<PackedScaledTruncOp> {
  PackedScaledTruncOpLowering(const LLVMTypeConverter &converter,
                              Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::PackedScaledTruncOp>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

} // end namespace
1421 
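// Illustrative example (assembly syntax abbreviated, not verbatim): an op
// such as
//   %f = amdgpu.ext_packed_fp8 %src[1] : vector<4xf8E4M3FN> to f32
// is lowered below to a bitcast of the packed source to i32 followed by
// rocdl.cvt.f32.fp8 selecting byte 1.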
1422 LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
1423  ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
1424  ConversionPatternRewriter &rewriter) const {
1425  Location loc = op.getLoc();
1426  if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
1427  return rewriter.notifyMatchFailure(
1428  loc, "Fp8 conversion instructions are not available on target "
1429  "architecture and their emulation is not implemented");
1430  Type v4i8 =
1431  getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type()));
1432  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
1433  Type f32 = getTypeConverter()->convertType(op.getResult().getType());
1434 
1435  Value source = adaptor.getSource();
1436  auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
1437  auto resultVecType = dyn_cast<VectorType>(op.getResult().getType());
1438  Type sourceElemType = getElementTypeOrSelf(op.getSource());
1439  // Extend to a v4i8
1440  if (!sourceVecType || sourceVecType.getNumElements() < 4) {
1441  Value longVec = LLVM::UndefOp::create(rewriter, loc, v4i8);
1442  if (!sourceVecType) {
1443  longVec = LLVM::InsertElementOp::create(
1444  rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0));
1445  } else {
1446  for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
1447  Value idx = createI32Constant(rewriter, loc, i);
1448  Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
1449  longVec =
1450  LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
1451  }
1452  }
1453  source = longVec;
1454  }
1455  Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
1456  if (resultVecType) {
1457  if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
1458  rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source,
1459  op.getIndex());
1460  } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
1461  rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Fp8Op>(op, f32, i32Source,
1462  op.getIndex());
1463  }
1464  } else {
1465  if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
1466  rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
1467  op.getIndex());
1468  } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
1469  rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
1470  op.getIndex());
1471  }
1472  }
1473  return success();
1474 }
1475 
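// Illustrative dispatch (types abbreviated): a vector<4xf8E5M2> source with
// an f16 result selects ROCDL::CvtScaleF32PkF16Bf8Op below, while a
// vector<8xf4E2M1FN> source with a bf16 result selects
// ROCDL::CvtScaleF32PkBf16Fp4Op.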
1476 LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
1477  ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
1478  ConversionPatternRewriter &rewriter) const {
1479  Location loc = op.getLoc();
1480  if (chipset != kGfx950)
1481  return rewriter.notifyMatchFailure(
1482  loc, "Scaled fp conversion instructions are not available on target "
1483  "architecture and their emulation is not implemented");
1484  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
1485 
1486  Value source = adaptor.getSource();
1487  Value scale = adaptor.getScale();
1488 
1489  VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
1490  Type sourceElemType = sourceVecType.getElementType();
1491  VectorType destVecType = cast<VectorType>(op.getResult().getType());
1492  Type destElemType = destVecType.getElementType();
1493 
1494  VectorType packedVecType;
1495  if (isa<Float8E5M2Type, Float8E4M3FNType>(sourceElemType)) {
1496  VectorType v4i8 = VectorType::get(4, rewriter.getI8Type());
1497  packedVecType = cast<VectorType>(getTypeConverter()->convertType(v4i8));
1498  } else if (isa<Float4E2M1FNType>(sourceElemType)) {
1499  VectorType v8i4 = VectorType::get(8, rewriter.getI4Type());
1500  packedVecType = cast<VectorType>(getTypeConverter()->convertType(v8i4));
1501  } else {
1502  llvm_unreachable("invalid element type for scaled ext");
1503  }
1504 
1505  // Extend the source to the packed vector type. `sourceVecType` is always
1506  // non-null here (the cast above would have asserted otherwise), so only
1507  // the element-copy loop is needed.
1508  if (sourceVecType.getNumElements() < packedVecType.getNumElements()) {
1509  Value longVec = LLVM::ZeroOp::create(rewriter, loc, packedVecType);
1510  for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
1511  Value idx = createI32Constant(rewriter, loc, i);
1512  Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
1513  longVec =
1514  LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
1515  }
1516  source = longVec;
1517  }
1521  Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
1522 
1523  if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF32())
1524  rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Bf8Op>(
1525  op, destVecType, i32Source, scale, op.getIndex());
1526  else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF16())
1527  rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Bf8Op>(
1528  op, destVecType, i32Source, scale, op.getIndex());
1529  else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isBF16())
1530  rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Bf8Op>(
1531  op, destVecType, i32Source, scale, op.getIndex());
1532  else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF32())
1533  rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp8Op>(
1534  op, destVecType, i32Source, scale, op.getIndex());
1535  else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF16())
1536  rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp8Op>(
1537  op, destVecType, i32Source, scale, op.getIndex());
1538  else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isBF16())
1539  rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp8Op>(
1540  op, destVecType, i32Source, scale, op.getIndex());
1541  else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF32())
1542  rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp4Op>(
1543  op, destVecType, i32Source, scale, op.getIndex());
1544  else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF16())
1545  rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp4Op>(
1546  op, destVecType, i32Source, scale, op.getIndex());
1547  else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isBF16())
1548  rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp4Op>(
1549  op, destVecType, i32Source, scale, op.getIndex());
1550  else
1551  return failure();
1552 
1553  return success();
1554 }
1555 
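// Illustrative dispatch: truncating two f32 values to f8E5M2 selects
// ROCDL::CvtScaleF32PkBf8F32Op below, which takes the two f32 lanes as
// separate sourceA/sourceB operands; f16 and bf16 sources instead pass the
// packed source vector through a single operand.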
1556 LogicalResult PackedScaledTruncOpLowering::matchAndRewrite(
1557  PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
1558  ConversionPatternRewriter &rewriter) const {
1559  Location loc = op.getLoc();
1560  if (chipset != kGfx950)
1561  return rewriter.notifyMatchFailure(
1562  loc, "Scaled fp conversion instructions are not available on target "
1563  "architecture and their emulation is not implemented");
1564  Type v2i16 = getTypeConverter()->convertType(
1565  VectorType::get(2, rewriter.getI16Type()));
1566  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
1567 
1568  Type resultType = op.getResult().getType();
1569  Type resultElemType = getElementTypeOrSelf(resultType);
1570  VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
1571  Type sourceElemType = sourceVecType.getElementType();
1572 
1573  Type intResultType = isa<Float4E2M1FNType>(resultElemType) ? i32 : v2i16;
1574 
1575  Value source = adaptor.getSource();
1576  Value scale = adaptor.getScale();
1577  Value existing = adaptor.getExisting();
1578  if (existing)
1579  existing = LLVM::BitcastOp::create(rewriter, loc, intResultType, existing);
1580  else
1581  existing = LLVM::ZeroOp::create(rewriter, loc, intResultType);
1582 
1583  if (sourceVecType.getNumElements() < 2) {
1584  Value c0 = createI32Constant(rewriter, loc, 0);
1585  Value elem0 = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
1586  VectorType v2 = VectorType::get(2, sourceElemType);
1587  source = LLVM::ZeroOp::create(rewriter, loc, v2);
1588  source = LLVM::InsertElementOp::create(rewriter, loc, source, elem0, c0);
1589  }
1590 
1591  Value sourceA, sourceB;
1592  if (sourceElemType.isF32()) {
1593  Value c0 = createI32Constant(rewriter, loc, 0);
1594  Value c1 = createI32Constant(rewriter, loc, 1);
1595  sourceA = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
1596  sourceB = LLVM::ExtractElementOp::create(rewriter, loc, source, c1);
1597  }
1598 
1599  Value result;
1600  if (sourceElemType.isF32() && isa<Float8E5M2Type>(resultElemType))
1601  result = ROCDL::CvtScaleF32PkBf8F32Op::create(rewriter, loc, intResultType,
1602  existing, sourceA, sourceB,
1603  scale, op.getIndex());
1604  else if (sourceElemType.isF16() && isa<Float8E5M2Type>(resultElemType))
1605  result = ROCDL::CvtScaleF32PkBf8F16Op::create(
1606  rewriter, loc, intResultType, existing, source, scale, op.getIndex());
1607  else if (sourceElemType.isBF16() && isa<Float8E5M2Type>(resultElemType))
1608  result = ROCDL::CvtScaleF32PkBf8Bf16Op::create(
1609  rewriter, loc, intResultType, existing, source, scale, op.getIndex());
1610  else if (sourceElemType.isF32() && isa<Float8E4M3FNType>(resultElemType))
1611  result = ROCDL::CvtScaleF32PkFp8F32Op::create(rewriter, loc, intResultType,
1612  existing, sourceA, sourceB,
1613  scale, op.getIndex());
1614  else if (sourceElemType.isF16() && isa<Float8E4M3FNType>(resultElemType))
1615  result = ROCDL::CvtScaleF32PkFp8F16Op::create(
1616  rewriter, loc, intResultType, existing, source, scale, op.getIndex());
1617  else if (sourceElemType.isBF16() && isa<Float8E4M3FNType>(resultElemType))
1618  result = ROCDL::CvtScaleF32PkFp8Bf16Op::create(
1619  rewriter, loc, intResultType, existing, source, scale, op.getIndex());
1620  else if (sourceElemType.isF32() && isa<Float4E2M1FNType>(resultElemType))
1621  result = ROCDL::CvtScaleF32PkFp4F32Op::create(rewriter, loc, intResultType,
1622  existing, sourceA, sourceB,
1623  scale, op.getIndex());
1624  else if (sourceElemType.isF16() && isa<Float4E2M1FNType>(resultElemType))
1625  result = ROCDL::CvtScaleF32PkFp4F16Op::create(
1626  rewriter, loc, intResultType, existing, source, scale, op.getIndex());
1627  else if (sourceElemType.isBF16() && isa<Float4E2M1FNType>(resultElemType))
1628  result = ROCDL::CvtScaleF32PkFp4Bf16Op::create(
1629  rewriter, loc, intResultType, existing, source, scale, op.getIndex());
1630  else
1631  return failure();
1632 
1633  rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
1634  op, getTypeConverter()->convertType(resultType), result);
1635  return success();
1636 }
1637 
1638 LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
1639  PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
1640  ConversionPatternRewriter &rewriter) const {
1641  Location loc = op.getLoc();
1642  if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
1643  return rewriter.notifyMatchFailure(
1644  loc, "Fp8 conversion instructions are not available on target "
1645  "architecture and their emulation is not implemented");
1646  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
1647 
1648  Type resultType = op.getResult().getType();
1649  Type resultElemType = getElementTypeOrSelf(resultType);
1650 
1651  Value sourceA = adaptor.getSourceA();
1652  Value sourceB = adaptor.getSourceB();
1653  if (!sourceB)
1654  sourceB = LLVM::UndefOp::create(rewriter, loc, sourceA.getType());
1655  Value existing = adaptor.getExisting();
1656  if (existing)
1657  existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
1658  else
1659  existing = LLVM::UndefOp::create(rewriter, loc, i32);
1660 
1661  Value result;
1662  if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
1663  result = ROCDL::CvtPkBf8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
1664  existing, op.getWordIndex());
1665  else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
1666  result = ROCDL::CvtPkFp8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
1667  existing, op.getWordIndex());
1668 
1669  rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
1670  op, getTypeConverter()->convertType(resultType), result);
1671  return success();
1672 }
1673 
1674 LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
1675  PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
1676  ConversionPatternRewriter &rewriter) const {
1677  Location loc = op.getLoc();
1678  if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
1679  return rewriter.notifyMatchFailure(
1680  loc, "Fp8 conversion instructions are not available on target "
1681  "architecture and their emulation is not implemented");
1682  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
1683 
1684  Type resultType = op.getResult().getType();
1685  Type resultElemType = getElementTypeOrSelf(resultType);
1686 
1687  Value source = adaptor.getSource();
1688  Value stoch = adaptor.getStochiasticParam();
1689  Value existing = adaptor.getExisting();
1690  if (existing)
1691  existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
1692  else
1693  existing = LLVM::UndefOp::create(rewriter, loc, i32);
1694 
1695  Value result;
1696  if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
1697  result = ROCDL::CvtSrBf8F32Op::create(rewriter, loc, i32, source, stoch,
1698  existing, op.getStoreIndex());
1699  else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
1700  result = ROCDL::CvtSrFp8F32Op::create(rewriter, loc, i32, source, stoch,
1701  existing, op.getStoreIndex());
1702 
1703  rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
1704  op, getTypeConverter()->convertType(resultType), result);
1705  return success();
1706 }
1707 
1708 // AMDGPUDPPLowering converts the amdgpu.dpp operation into the
1709 // corresponding ROCDL DPP operation.
1710 struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
1711  AMDGPUDPPLowering(const LLVMTypeConverter &converter, Chipset chipset)
1712  : ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
1713  Chipset chipset;
1714 
1715  LogicalResult
1716  matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
1717  ConversionPatternRewriter &rewriter) const override {
1718 
1719  // Convert the source operand to the corresponding LLVM type
1720  Location loc = DppOp.getLoc();
1721  Value src = adaptor.getSrc();
1722  Value old = adaptor.getOld();
1723  Type srcType = src.getType();
1724  Type oldType = old.getType();
1725  Type llvmType = nullptr;
1726  if (srcType.getIntOrFloatBitWidth() < 32) {
1727  llvmType = rewriter.getI32Type();
1728  } else if (isa<FloatType>(srcType)) {
1729  llvmType = (srcType.getIntOrFloatBitWidth() == 32)
1730  ? rewriter.getF32Type()
1731  : rewriter.getF64Type();
1732  } else if (isa<IntegerType>(srcType)) {
1733  llvmType = (srcType.getIntOrFloatBitWidth() == 32)
1734  ? rewriter.getI32Type()
1735  : rewriter.getI64Type();
1736  }
1737  auto llvmSrcIntType = typeConverter->convertType(
1738  rewriter.getIntegerType(srcType.getIntOrFloatBitWidth()));
1739 
1740  // If the source type is narrower than 32 bits, bitcast it to i32.
1741  auto convertOperand = [&](Value operand, Type operandType) {
1742  if (operandType.getIntOrFloatBitWidth() <= 16) {
1743  if (llvm::isa<FloatType>(operandType)) {
1744  operand =
1745  LLVM::BitcastOp::create(rewriter, loc, llvmSrcIntType, operand);
1746  }
1747  auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
1748  32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
1749  Value undefVec = LLVM::UndefOp::create(rewriter, loc, llvmVecType);
1750  operand =
1751  LLVM::InsertElementOp::create(rewriter, loc, undefVec, operand,
1752  createI32Constant(rewriter, loc, 0));
1753  operand = LLVM::BitcastOp::create(rewriter, loc, llvmType, operand);
1754  }
1755  return operand;
1756  };
1757 
1758  src = convertOperand(src, srcType);
1759  old = convertOperand(old, oldType);
1760 
1761  // These values are taken from llvm/lib/Target/AMDGPU/SIDefines.h.
1762  enum DppCtrl : unsigned {
1763  ROW_SHL0 = 0x100,
1764  ROW_SHR0 = 0x110,
1765  ROW_ROR0 = 0x120,
1766  WAVE_SHL1 = 0x130,
1767  WAVE_ROL1 = 0x134,
1768  WAVE_SHR1 = 0x138,
1769  WAVE_ROR1 = 0x13C,
1770  ROW_MIRROR = 0x140,
1771  ROW_HALF_MIRROR = 0x141,
1772  BCAST15 = 0x142,
1773  BCAST31 = 0x143,
1774  };
1775 
1776  auto kind = DppOp.getKind();
1777  auto permArgument = DppOp.getPermArgument();
1778  uint32_t DppCtrl = 0;
1779 
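  // Worked example (illustrative): for quad_perm [1, 0, 3, 2], the quad_perm
  // case below packs each 2-bit lane selector into DppCtrl:
  //   (1 << 0) | (0 << 2) | (3 << 4) | (2 << 6) = 0xB1.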
1780  switch (kind) {
1781 
1782  case DPPPerm::quad_perm:
1783  if (auto quadPermAttr = cast<ArrayAttr>(*permArgument)) {
1784  int32_t i = 0;
1785  for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
1786  uint32_t num = elem.getInt();
1787  DppCtrl |= num << (i * 2);
1788  i++;
1789  }
1790  }
1791  break;
1792  case DPPPerm::row_shl:
1793  if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
1794  DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
1795  }
1796  break;
1797  case DPPPerm::row_shr:
1798  if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
1799  DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
1800  }
1801  break;
1802  case DPPPerm::row_ror:
1803  if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
1804  DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
1805  }
1806  break;
1807  case DPPPerm::wave_shl:
1808  DppCtrl = DppCtrl::WAVE_SHL1;
1809  break;
1810  case DPPPerm::wave_shr:
1811  DppCtrl = DppCtrl::WAVE_SHR1;
1812  break;
1813  case DPPPerm::wave_rol:
1814  DppCtrl = DppCtrl::WAVE_ROL1;
1815  break;
1816  case DPPPerm::wave_ror:
1817  DppCtrl = DppCtrl::WAVE_ROR1;
1818  break;
1819  case DPPPerm::row_mirror:
1820  DppCtrl = DppCtrl::ROW_MIRROR;
1821  break;
1822  case DPPPerm::row_half_mirror:
1823  DppCtrl = DppCtrl::ROW_HALF_MIRROR;
1824  break;
1825  case DPPPerm::row_bcast_15:
1826  DppCtrl = DppCtrl::BCAST15;
1827  break;
1828  case DPPPerm::row_bcast_31:
1829  DppCtrl = DppCtrl::BCAST31;
1830  break;
1831  }
1832 
1833  // Read the row_mask, bank_mask, and bound_ctrl attribute values off the
1834  // operation.
1835  auto rowMask = DppOp->getAttrOfType<IntegerAttr>("row_mask").getInt();
1836  auto bankMask = DppOp->getAttrOfType<IntegerAttr>("bank_mask").getInt();
1837  bool boundCtrl = DppOp->getAttrOfType<BoolAttr>("bound_ctrl").getValue();
1838 
1839  // Create a ROCDL DPPUpdateOp with the computed control value and masks.
1840  auto dppMovOp =
1841  ROCDL::DPPUpdateOp::create(rewriter, loc, llvmType, old, src, DppCtrl,
1842  rowMask, bankMask, boundCtrl);
1843 
1844  Value result = dppMovOp.getRes();
1845  if (srcType.getIntOrFloatBitWidth() < 32) {
1846  result = LLVM::TruncOp::create(rewriter, loc, llvmSrcIntType, result);
1847  if (!llvm::isa<IntegerType>(srcType)) {
1848  result = LLVM::BitcastOp::create(rewriter, loc, srcType, result);
1849  }
1850  }
1851 
1852  // Replace the amdgpu.dpp operation with the result of the new
1853  // ROCDL DPPUpdateOp.
1854  rewriter.replaceOp(DppOp, ValueRange(result));
1855  return success();
1856  }
1857 };
1858 
1859 struct AMDGPUSwizzleBitModeLowering
1860  : public ConvertOpToLLVMPattern<SwizzleBitModeOp> {
1861  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
1862 
1863  LogicalResult
1864  matchAndRewrite(SwizzleBitModeOp op, OpAdaptor adaptor,
1865  ConversionPatternRewriter &rewriter) const override {
1866  Location loc = op.getLoc();
1867  Type i32 = rewriter.getI32Type();
1868  Value src = adaptor.getSrc();
1869  SmallVector<Value> decomposed =
1870  LLVM::decomposeValue(rewriter, loc, src, i32);
1871  unsigned andMask = op.getAndMask();
1872  unsigned orMask = op.getOrMask();
1873  unsigned xorMask = op.getXorMask();
1874 
1875  // Bit 15 is 0 for the BitMode swizzle; the and/or/xor masks occupy bits [4:0], [9:5], and [14:10].
1876  // https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/
1877  unsigned mask = andMask | (orMask << 5) | (xorMask << 10);
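  // Worked example (illustrative): andMask = 0x1f, orMask = 0, xorMask = 1
  // gives mask = 0x1f | (0 << 5) | (1 << 10) = 0x41f, which swaps each even
  // lane with its odd neighbor.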
1878  Value maskValue = createI32Constant(rewriter, loc, mask);
1879  SmallVector<Value> swizzled;
1880  for (Value v : decomposed) {
1881  Value res =
1882  ROCDL::DsSwizzleOp::create(rewriter, loc, v.getType(), v, maskValue);
1883  swizzled.emplace_back(res);
1884  }
1885 
1886  Value result = LLVM::composeValue(rewriter, loc, swizzled, src.getType());
1887  rewriter.replaceOp(op, result);
1888  return success();
1889  }
1890 };
1891 
1892 struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneSwapOp> {
1893  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
1894 
1895  AMDGPUPermlaneLowering(const LLVMTypeConverter &converter, Chipset chipset)
1896  : ConvertOpToLLVMPattern<PermlaneSwapOp>(converter), chipset(chipset) {}
1897  Chipset chipset;
1898 
1899  LogicalResult
1900  matchAndRewrite(PermlaneSwapOp op, OpAdaptor adaptor,
1901  ConversionPatternRewriter &rewriter) const override {
1902  if (chipset < kGfx950)
1903  return op->emitOpError("permlane_swap is only supported on gfx950+");
1904 
1905  Location loc = op.getLoc();
1906  Type i32 = rewriter.getI32Type();
1907  Value src = adaptor.getSrc();
1908  unsigned rowLength = op.getRowLength();
1909  bool fi = op.getFetchInactive();
1910  bool boundctrl = op.getBoundCtrl();
1911 
1912  SmallVector<Value> decomposed =
1913  LLVM::decomposeValue(rewriter, loc, src, i32);
1914 
1915  SmallVector<Value> permuted;
1916  for (Value v : decomposed) {
1917  Value res;
1918  Type i32pair = LLVM::LLVMStructType::getLiteral(
1919  rewriter.getContext(), {v.getType(), v.getType()});
1920 
1921  if (rowLength == 16)
1922  res = ROCDL::Permlane16SwapOp::create(rewriter, loc, i32pair, v, v, fi,
1923  boundctrl);
1924  else if (rowLength == 32)
1925  res = ROCDL::Permlane32SwapOp::create(rewriter, loc, i32pair, v, v, fi,
1926  boundctrl);
1927  else
1928  llvm_unreachable("unsupported row length");
1929 
1930  const Value vdst0 = LLVM::ExtractValueOp::create(rewriter, loc, res, {0});
1931  const Value vdst1 = LLVM::ExtractValueOp::create(rewriter, loc, res, {1});
1932 
1933  const Value isEqual = LLVM::ICmpOp::create(
1934  rewriter, loc, LLVM::ICmpPredicate::eq, vdst0, v);
1935 
1936  // Per `permlane(16|32)` semantics: if the first extracted element equals
1937  // 'v', the result is the second element; otherwise it is the first.
1938  Value vdstNew =
1939  LLVM::SelectOp::create(rewriter, loc, isEqual, vdst1, vdst0);
1940  permuted.emplace_back(vdstNew);
1941  }
1942 
1943  Value result = LLVM::composeValue(rewriter, loc, permuted, src.getType());
1944  rewriter.replaceOp(op, result);
1945  return success();
1946  }
1947 };
1948 
1949 struct ConvertAMDGPUToROCDLPass
1950  : public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
1951  using Base::Base;
1952 
1953  void runOnOperation() override {
1954  MLIRContext *ctx = &getContext();
1955  FailureOr<Chipset> maybeChipset = Chipset::parse(chipset);
1956  if (failed(maybeChipset)) {
1957  emitError(UnknownLoc::get(ctx), "invalid chipset name: " + chipset);
1958  return signalPassFailure();
1959  }
1960 
1961  RewritePatternSet patterns(ctx);
1962  LLVMTypeConverter converter(ctx);
1963  populateAMDGPUToROCDLConversionPatterns(converter, patterns, *maybeChipset);
1964  LLVMConversionTarget target(getContext());
1965  target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
1966  target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
1967  target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
1968  if (failed(applyPartialConversion(getOperation(), target,
1969  std::move(patterns))))
1970  signalPassFailure();
1971  }
1972 };
1973 } // namespace
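// Example invocation of the pass from the command line (illustrative):
//   mlir-opt --convert-amdgpu-to-rocdl='chipset=gfx942' input.mlir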
1974 
1975 void mlir::populateAMDGPUMemorySpaceAttributeConversions(
1976  TypeConverter &typeConverter) {
1977  typeConverter.addTypeAttributeConversion(
1978  [](BaseMemRefType type, amdgpu::AddressSpaceAttr as)
1979  -> TypeConverter::AttributeConversionResult {
1980  MLIRContext *ctx = as.getContext();
1981  Type i64 = IntegerType::get(ctx, 64);
1982  switch (as.getValue()) {
1983  case amdgpu::AddressSpace::FatRawBuffer:
1984  return IntegerAttr::get(i64, 7);
1985  case amdgpu::AddressSpace::BufferRsrc:
1986  return IntegerAttr::get(i64, 8);
1987  case amdgpu::AddressSpace::FatStructuredBuffer:
1988  return IntegerAttr::get(i64, 9);
1989  }
1990  return TypeConverter::AttributeConversionResult::abort();
1991  });
1992 }
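// For example (illustrative; attribute syntax abbreviated), with the
// conversion above a memref<8xf32, #amdgpu.address_space<fat_raw_buffer>>
// is rewritten to use LLVM address space 7.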
1993 
1994 void mlir::populateAMDGPUToROCDLConversionPatterns(
1995  LLVMTypeConverter &converter, RewritePatternSet &patterns,
1996  Chipset chipset) {
1997  populateAMDGPUMemorySpaceAttributeConversions(converter);
1998  patterns
1999  .add<FatRawBufferCastLowering,
2000  RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
2001  RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
2002  RawBufferOpLowering<RawBufferAtomicFaddOp,
2003  ROCDL::RawPtrBufferAtomicFaddOp>,
2004  RawBufferOpLowering<RawBufferAtomicFmaxOp,
2005  ROCDL::RawPtrBufferAtomicFmaxOp>,
2006  RawBufferOpLowering<RawBufferAtomicSmaxOp,
2007  ROCDL::RawPtrBufferAtomicSmaxOp>,
2008  RawBufferOpLowering<RawBufferAtomicUminOp,
2009  ROCDL::RawPtrBufferAtomicUminOp>,
2010  RawBufferOpLowering<RawBufferAtomicCmpswapOp,
2011  ROCDL::RawPtrBufferAtomicCmpSwap>,
2012  AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
2013  SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
2014  WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
2015  PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
2016  PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
2017  TransposeLoadOpLowering, AMDGPUPermlaneLowering>(converter, chipset);
2018  patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
2019 }