//===- AMDGPUToROCDL.cpp - AMDGPU to ROCDL dialect conversion -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"

#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"

#include "../LLVMCommon/MemRefDescriptor.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"
#include <limits>
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_CONVERTAMDGPUTOROCDLPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::amdgpu;

// Define commonly used chipset versions for convenience.
constexpr Chipset kGfx908 = Chipset(9, 0, 8);
constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
constexpr Chipset kGfx942 = Chipset(9, 4, 2);
constexpr Chipset kGfx950 = Chipset(9, 5, 0);

/// Convert an unsigned number `val` to i32.
static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
                                  Location loc, Value val) {
  IntegerType i32 = rewriter.getI32Type();
  // Force a check that `val` is of integer type.
  auto valTy = cast<IntegerType>(val.getType());
  if (i32 == valTy)
    return val;
  return valTy.getWidth() > 32
             ? Value(LLVM::TruncOp::create(rewriter, loc, i32, val))
             : Value(LLVM::ZExtOp::create(rewriter, loc, i32, val));
}

static Value createI32Constant(ConversionPatternRewriter &rewriter,
                               Location loc, int32_t value) {
  Type i32 = rewriter.getI32Type();
  return LLVM::ConstantOp::create(rewriter, loc, i32, value);
}

static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
                              bool value) {
  Type llvmI1 = rewriter.getI1Type();
  return LLVM::ConstantOp::create(rewriter, loc, llvmI1, value);
}

/// Returns the linear index used to access an element in the memref.
static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
                               Location loc, MemRefDescriptor &memRefDescriptor,
                               ValueRange indices, ArrayRef<int64_t> strides) {
  IntegerType i32 = rewriter.getI32Type();
  Value index;
  for (auto [i, increment, stride] : llvm::enumerate(indices, strides)) {
    if (stride != 1) { // Skip if stride is 1.
      Value strideValue =
          ShapedType::isDynamic(stride)
              ? convertUnsignedToI32(rewriter, loc,
                                     memRefDescriptor.stride(rewriter, loc, i))
              : LLVM::ConstantOp::create(rewriter, loc, i32, stride);
      increment = LLVM::MulOp::create(rewriter, loc, increment, strideValue);
    }
    index = index ? LLVM::AddOp::create(rewriter, loc, index, increment)
                  : increment;
  }
  return index ? index : createI32Constant(rewriter, loc, 0);
}
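
// Worked example (illustrative): indices (i, j) into a memref with static
// strides [4, 1] linearize to `i * 4 + j`; the multiply for the unit stride
// is skipped entirely.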

/// Compute the contents of the `num_records` field for a given memref
/// descriptor - that is, the number of bytes that's one element past the
/// greatest possible valid index into the memref.
static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc,
                           MemRefType memrefType,
                           MemRefDescriptor &memrefDescriptor,
                           ArrayRef<int64_t> strides,
                           uint32_t elementByteWidth) {
  if (memrefType.hasStaticShape() &&
      !llvm::any_of(strides, ShapedType::isDynamic)) {
    int64_t size = memrefType.getRank() == 0 ? 1 : 0;
    ArrayRef<int64_t> shape = memrefType.getShape();
    for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
      size = std::max(shape[i] * strides[i], size);
    size = size * elementByteWidth;
    assert(size < std::numeric_limits<uint32_t>::max() &&
           "the memref buffer is too large");
    return createI32Constant(rewriter, loc, static_cast<int32_t>(size));
  }
  Value maxIndex;
  for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
    Value size = memrefDescriptor.size(rewriter, loc, i);
    Value stride = memrefDescriptor.stride(rewriter, loc, i);
    Value maxThisDim = LLVM::MulOp::create(rewriter, loc, size, stride);
    maxIndex = maxIndex
                   ? LLVM::UMaxOp::create(rewriter, loc, maxIndex, maxThisDim)
                   : maxThisDim;
  }
  Value maxIndexI32 = convertUnsignedToI32(rewriter, loc, maxIndex);
  Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);
  return LLVM::MulOp::create(rewriter, loc, maxIndexI32, byteWidthConst);
}
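
// Worked example (illustrative): for a static memref<16x4xf32> with strides
// [4, 1], the static branch above computes max(16 * 4, 4 * 1) = 64 elements,
// so num_records = 64 * 4 bytes = 256.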

static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc,
                            Value basePointer, Value numRecords,
                            bool boundsCheck, amdgpu::Chipset chipset,
                            Value cacheSwizzleStride = nullptr,
                            unsigned addressSpace = 8) {
  // The stride value is generally 0. However, on MI-300 and onward, you can
  // enable a cache swizzling mode by setting bit 14 of the stride field
  // and setting that stride to a cache stride.
  Type i16 = rewriter.getI16Type();
  Value stride;
  if (chipset.majorVersion == 9 && chipset >= kGfx942 && cacheSwizzleStride) {
    Value cacheStrideZext =
        LLVM::ZExtOp::create(rewriter, loc, i16, cacheSwizzleStride);
    Value swizzleBit = LLVM::ConstantOp::create(
        rewriter, loc, i16, rewriter.getI16IntegerAttr(1 << 14));
    stride = LLVM::OrOp::create(rewriter, loc, cacheStrideZext, swizzleBit,
                                /*isDisjoint=*/true);
  } else {
    stride = LLVM::ConstantOp::create(rewriter, loc, i16,
                                      rewriter.getI16IntegerAttr(0));
  }
  // Get the number of elements.
  // Flag word:
  // bits 0-11: dst sel, ignored by these intrinsics
  // bits 12-14: num format (ignored, must be nonzero, 7=float)
  // bits 15-18: data format (ignored, must be nonzero, 4=32bit)
  // bit 19: In nested heap (0 here)
  // bit 20: Behavior on unmap (0 means "return 0 / ignore")
  // bits 21-22: Index stride for swizzles (N/A)
  // bit 23: Add thread ID (0)
  // bit 24: Reserved to 1 (RDNA) or 0 (CDNA)
  // bits 25-26: Reserved (0)
  // bit 27: Buffer is non-volatile (CDNA only)
  // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 =
  // none, 3 = either swizzles or testing against offset field) RDNA only
  // bits 30-31: Type (must be 0)
  uint32_t flags = (7 << 12) | (4 << 15);
  if (chipset.majorVersion >= 10) {
    flags |= (1 << 24);
    uint32_t oob = boundsCheck ? 3 : 2;
    flags |= (oob << 28);
  }
  Value flagsConst = createI32Constant(rewriter, loc, flags);
  Type rsrcType =
      LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
  Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
      loc, rsrcType, basePointer, stride, numRecords, flagsConst);
  return resource;
}
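
// Illustrative compile-time check of the flag word computed above, assuming a
// gfx10+ target with bounds checking enabled (one configuration among
// several, chosen for the example):
static_assert(((7u << 12) | (4u << 15) | (1u << 24) | (3u << 28)) ==
                  0x31027000u,
              "example gfx10+ raw buffer flag word");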

namespace {
struct FatRawBufferCastLowering
    : public ConvertOpToLLVMPattern<FatRawBufferCastOp> {
  FatRawBufferCastLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<FatRawBufferCastOp>(converter),
        chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value memRef = adaptor.getSource();
    Value unconvertedMemref = op.getSource();
    MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());
    MemRefDescriptor descriptor(memRef);

    DataLayout dataLayout = DataLayout::closest(op);
    int64_t elementByteWidth =
        dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;

    int64_t unusedOffset = 0;
    SmallVector<int64_t, 5> strideVals;
    if (failed(memrefType.getStridesAndOffset(strideVals, unusedOffset)))
      return op.emitOpError("Can't lower non-stride-offset memrefs");

    Value numRecords = adaptor.getValidBytes();
    if (!numRecords)
      numRecords = getNumRecords(rewriter, loc, memrefType, descriptor,
                                 strideVals, elementByteWidth);

    Value basePointer =
        adaptor.getResetOffset()
            ? descriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
                                   memrefType)
            : descriptor.alignedPtr(rewriter, loc);

    Value offset = adaptor.getResetOffset()
                       ? LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
                                                  rewriter.getIndexAttr(0))
                       : descriptor.offset(rewriter, loc);

    bool hasSizes = memrefType.getRank() > 0;
    // No need to unpack() and pack() all the individual sizes and strides,
    // so we'll just extract the arrays.
    Value sizes = hasSizes
                      ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
                                                     kSizePosInMemRefDescriptor)
                      : Value{};
    Value strides =
        hasSizes ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
                                                kStridePosInMemRefDescriptor)
                 : Value{};

    Value fatPtr = makeBufferRsrc(
        rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(),
        chipset, adaptor.getCacheSwizzleStride(), /*addressSpace=*/7);

    Value result = MemRefDescriptor::poison(
        rewriter, loc,
        getTypeConverter()->convertType(op.getResult().getType()));
    SmallVector<int64_t> pos{kAllocatedPtrPosInMemRefDescriptor};
    result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr, pos);
    result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr,
                                         kAlignedPtrPosInMemRefDescriptor);
    result = LLVM::InsertValueOp::create(rewriter, loc, result, offset,
                                         kOffsetPosInMemRefDescriptor);
    if (hasSizes) {
      result = LLVM::InsertValueOp::create(rewriter, loc, result, sizes,
                                           kSizePosInMemRefDescriptor);
      result = LLVM::InsertValueOp::create(rewriter, loc, result, strides,
                                           kStridePosInMemRefDescriptor);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// Define lowering patterns for raw buffer ops
template <typename GpuOp, typename Intrinsic>
struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
  RawBufferOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}

  Chipset chipset;
  static constexpr uint32_t maxVectorOpWidth = 128;

  LogicalResult
  matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = gpuOp.getLoc();
    Value memref = adaptor.getMemref();
    Value unconvertedMemref = gpuOp.getMemref();
    MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());

    if (chipset.majorVersion < 9)
      return gpuOp.emitOpError("raw buffer ops require GCN or higher");

    Value storeData = adaptor.getODSOperands(0)[0];
    if (storeData == memref) // no write component to this op
      storeData = Value();
    Type wantedDataType;
    if (storeData)
      wantedDataType = storeData.getType();
    else
      wantedDataType = gpuOp.getODSResults(0)[0].getType();

    Value atomicCmpData = Value();
    // Operand index 1 of a load is the indices; trying to read them can crash.
    if (storeData) {
      Value maybeCmpData = adaptor.getODSOperands(1)[0];
      if (maybeCmpData != memref)
        atomicCmpData = maybeCmpData;
    }

    Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);

    Type i32 = rewriter.getI32Type();

    // Get the type size in bytes.
    DataLayout dataLayout = DataLayout::closest(gpuOp);
    int64_t elementByteWidth =
        dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;
    Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);

    // If we want to load a vector<NxT> with total size <= 32
    // bits, use a scalar load and bitcast it. Similarly, if bitsize(T) < 32
    // and the total load size is >= 32, use a vector load of N / (bitsize(T) /
    // 32) x i32 and bitcast. Also, the CAS intrinsic requires integer operands,
    // so bitcast any floats to integers.
    Type llvmBufferValType = llvmWantedDataType;
    if (atomicCmpData) {
      if (auto floatType = dyn_cast<FloatType>(wantedDataType))
        llvmBufferValType = this->getTypeConverter()->convertType(
            rewriter.getIntegerType(floatType.getWidth()));
    }
    if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
      uint32_t vecLen = dataVector.getNumElements();
      uint32_t elemBits =
          dataLayout.getTypeSizeInBits(dataVector.getElementType());
      uint32_t totalBits = elemBits * vecLen;
      bool usePackedFp16 =
          isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
      if (totalBits > maxVectorOpWidth)
        return gpuOp.emitOpError(
            "Total width of loads or stores must be no more than " +
            Twine(maxVectorOpWidth) + " bits, but we call for " +
            Twine(totalBits) +
            " bits. This should've been caught in validation");
      if (!usePackedFp16 && elemBits < 32) {
        if (totalBits > 32) {
          if (totalBits % 32 != 0)
            return gpuOp.emitOpError("Load or store of more than 32-bits that "
                                     "doesn't fit into words. Can't happen\n");
          llvmBufferValType = this->typeConverter->convertType(
              VectorType::get(totalBits / 32, i32));
        } else {
          llvmBufferValType = this->typeConverter->convertType(
              rewriter.getIntegerType(totalBits));
        }
      }
    }
    if (auto vecType = dyn_cast<VectorType>(llvmBufferValType)) {
      // Buffer intrinsics don't support 1-element vectors; cast them to
      // scalars.
      if (vecType.getNumElements() == 1)
        llvmBufferValType = vecType.getElementType();
    }

    SmallVector<Value, 6> args;
    if (storeData) {
      if (llvmBufferValType != llvmWantedDataType) {
        Value castForStore = LLVM::BitcastOp::create(
            rewriter, loc, llvmBufferValType, storeData);
        args.push_back(castForStore);
      } else {
        args.push_back(storeData);
      }
    }

    if (atomicCmpData) {
      if (llvmBufferValType != llvmWantedDataType) {
        Value castForCmp = LLVM::BitcastOp::create(
            rewriter, loc, llvmBufferValType, atomicCmpData);
        args.push_back(castForCmp);
      } else {
        args.push_back(atomicCmpData);
      }
    }

    // Construct the buffer descriptor from the memref and attributes.
    int64_t offset = 0;
    SmallVector<int64_t, 5> strides;
    if (failed(memrefType.getStridesAndOffset(strides, offset)))
      return gpuOp.emitOpError("Can't lower non-stride-offset memrefs");

    MemRefDescriptor memrefDescriptor(memref);

    Value ptr = memrefDescriptor.bufferPtr(
        rewriter, loc, *this->getTypeConverter(), memrefType);
    Value numRecords = getNumRecords(
        rewriter, loc, memrefType, memrefDescriptor, strides, elementByteWidth);
    Value resource = makeBufferRsrc(rewriter, loc, ptr, numRecords,
                                    adaptor.getBoundsCheck(), chipset);
    args.push_back(resource);

    // Indexing (voffset)
    Value voffset = getLinearIndexI32(rewriter, loc, memrefDescriptor,
                                      adaptor.getIndices(), strides);
    if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
        indexOffset && *indexOffset > 0) {
      Value extraOffsetConst = createI32Constant(rewriter, loc, *indexOffset);
      voffset = voffset ? LLVM::AddOp::create(rewriter, loc, voffset,
                                              extraOffsetConst)
                        : extraOffsetConst;
    }
    voffset = LLVM::MulOp::create(rewriter, loc, voffset, byteWidthConst);
    args.push_back(voffset);

    // SGPR offset.
    Value sgprOffset = adaptor.getSgprOffset();
    if (!sgprOffset)
      sgprOffset = createI32Constant(rewriter, loc, 0);
    sgprOffset = LLVM::MulOp::create(rewriter, loc, sgprOffset, byteWidthConst);
    args.push_back(sgprOffset);

    // bit 0: GLC = 0 (atomics drop value, less coherency)
    // bits 1-2: SLC, DLC = 0 (similarly)
    // bit 3: swizzled (0 for raw)
    args.push_back(createI32Constant(rewriter, loc, 0));

    llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
                                           llvmBufferValType);
    Operation *lowered = Intrinsic::create(rewriter, loc, resultTypes, args,
                                           ArrayRef<NamedAttribute>());
    if (lowered->getNumResults() == 1) {
      Value replacement = lowered->getResult(0);
      if (llvmBufferValType != llvmWantedDataType) {
        replacement = LLVM::BitcastOp::create(rewriter, loc, llvmWantedDataType,
                                              replacement);
      }
      rewriter.replaceOp(gpuOp, replacement);
    } else {
      rewriter.eraseOp(gpuOp);
    }
    return success();
  }
};

// TODO: the AMDGPU backend already has all this bitpacking logic; we should
// move it to some common place.
/// Vmcnt, Expcnt and Lgkmcnt are decoded as follows:
/// Vmcnt = Waitcnt[3:0] (pre-gfx9)
/// Vmcnt = Waitcnt[15:14,3:0] (gfx9,10)
/// Vmcnt = Waitcnt[15:10] (gfx11)
/// Expcnt = Waitcnt[6:4] (pre-gfx11)
/// Expcnt = Waitcnt[2:0] (gfx11)
/// Lgkmcnt = Waitcnt[11:8] (pre-gfx10)
/// Lgkmcnt = Waitcnt[13:8] (gfx10)
/// Lgkmcnt = Waitcnt[9:4] (gfx11)
static FailureOr<unsigned> encodeWaitcnt(Chipset chipset, unsigned vmcnt,
                                         unsigned expcnt, unsigned lgkmcnt) {
  if (chipset.majorVersion < 9) {
    vmcnt = std::min(15u, vmcnt);
    expcnt = std::min(7u, expcnt);
    lgkmcnt = std::min(15u, lgkmcnt);
    return vmcnt | (expcnt << 4) | (lgkmcnt << 8);
  }
  if (chipset.majorVersion == 9) {
    vmcnt = std::min(63u, vmcnt);
    expcnt = std::min(7u, expcnt);
    lgkmcnt = std::min(15u, lgkmcnt);
    unsigned lowBits = vmcnt & 0xF;
    unsigned highBits = (vmcnt >> 4) << 14;
    unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
    return lowBits | highBits | otherCnts;
  }
  if (chipset.majorVersion == 10) {
    vmcnt = std::min(63u, vmcnt);
    expcnt = std::min(7u, expcnt);
    lgkmcnt = std::min(63u, lgkmcnt);
    unsigned lowBits = vmcnt & 0xF;
    unsigned highBits = (vmcnt >> 4) << 14;
    unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
    return lowBits | highBits | otherCnts;
  }
  if (chipset.majorVersion == 11) {
    vmcnt = std::min(63u, vmcnt);
    expcnt = std::min(7u, expcnt);
    lgkmcnt = std::min(63u, lgkmcnt);
    return (vmcnt << 10) | expcnt | (lgkmcnt << 4);
  }
  return failure();
}
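
// Worked example (illustrative): on gfx9, vmcnt = 18, expcnt = 3, lgkmcnt = 5
// split vmcnt across the two fields documented above, giving
//   (18 & 0xF) | ((18 >> 4) << 14) | (3 << 4) | (5 << 8) == 0x4532.
static_assert(((18u & 0xF) | ((18u >> 4) << 14) | (3u << 4) | (5u << 8)) ==
                  0x4532u,
              "example gfx9 waitcnt encoding");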

struct MemoryCounterWaitOpLowering
    : public ConvertOpToLLVMPattern<MemoryCounterWaitOp> {
  MemoryCounterWaitOpLowering(const LLVMTypeConverter &converter,
                              Chipset chipset)
      : ConvertOpToLLVMPattern<MemoryCounterWaitOp>(converter),
        chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(MemoryCounterWaitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset.majorVersion >= 12) {
      Location loc = op.getLoc();
      if (std::optional<int> ds = adaptor.getDs())
        ROCDL::WaitDscntOp::create(rewriter, loc, *ds);

      if (std::optional<int> load = adaptor.getLoad())
        ROCDL::WaitLoadcntOp::create(rewriter, loc, *load);

      if (std::optional<int> store = adaptor.getStore())
        ROCDL::WaitStorecntOp::create(rewriter, loc, *store);

      if (std::optional<int> exp = adaptor.getExp())
        ROCDL::WaitExpcntOp::create(rewriter, loc, *exp);

      rewriter.eraseOp(op);
      return success();
    }

    auto getVal = [](Attribute attr) -> unsigned {
      if (attr)
        return cast<IntegerAttr>(attr).getInt();

      // This value will be clamped to the maximum value for the chipset.
      return 1024;
    };
    unsigned ds = getVal(adaptor.getDsAttr());
    unsigned exp = getVal(adaptor.getExpAttr());

    unsigned vmcnt = 1024;
    Attribute load = adaptor.getLoadAttr();
    Attribute store = adaptor.getStoreAttr();
    if (load && store) {
      vmcnt = getVal(load) + getVal(store);
    } else if (load) {
      vmcnt = getVal(load);
    } else if (store) {
      vmcnt = getVal(store);
    }

    FailureOr<unsigned> waitcnt = encodeWaitcnt(chipset, vmcnt, exp, ds);
    if (failed(waitcnt))
      return op.emitOpError("unsupported chipset");

    rewriter.replaceOpWithNewOp<ROCDL::SWaitcntOp>(op, *waitcnt);
    return success();
  }
};

struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
  LDSBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    bool requiresInlineAsm = chipset < kGfx90a || chipset.majorVersion == 11;

    if (requiresInlineAsm) {
      auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
                                                      LLVM::AsmDialect::AD_ATT);
      const char *asmStr =
          ";;;WARNING: BREAKS DEBUG WATCHES\ns_waitcnt lgkmcnt(0)\ns_barrier";
      const char *constraints = "";
      rewriter.replaceOpWithNewOp<LLVM::InlineAsmOp>(
          op,
          /*resultTypes=*/TypeRange(), /*operands=*/ValueRange(),
          /*asm_string=*/asmStr, constraints, /*has_side_effects=*/true,
          /*is_align_stack=*/false, LLVM::TailCallKind::None,
          /*asm_dialect=*/asmDialectAttr,
          /*operand_attrs=*/ArrayAttr());
      return success();
    }
    if (chipset.majorVersion < 12) {
      constexpr int32_t ldsOnlyBitsGfx6789 = ~(0x1f << 8);
      constexpr int32_t ldsOnlyBitsGfx10 = ~(0x3f << 8);
      // Left in place in case someone disables the inline ASM path or future
      // chipsets use the same bit pattern.
      constexpr int32_t ldsOnlyBitsGfx11 = ~(0x3f << 4);

      int32_t ldsOnlyBits;
      if (chipset.majorVersion == 11)
        ldsOnlyBits = ldsOnlyBitsGfx11;
      else if (chipset.majorVersion == 10)
        ldsOnlyBits = ldsOnlyBitsGfx10;
      else if (chipset.majorVersion <= 9)
        ldsOnlyBits = ldsOnlyBitsGfx6789;
      else
        return op.emitOpError(
                   "don't know how to lower this for chipset major version")
               << chipset.majorVersion;

      Location loc = op->getLoc();
      ROCDL::SWaitcntOp::create(rewriter, loc, ldsOnlyBits);
      rewriter.replaceOpWithNewOp<ROCDL::SBarrierOp>(op);
    } else {
      Location loc = op->getLoc();
      ROCDL::WaitDscntOp::create(rewriter, loc, 0);
      ROCDL::BarrierSignalOp::create(rewriter, loc, -1);
      rewriter.replaceOpWithNewOp<ROCDL::BarrierWaitOp>(op, -1);
    }

    return success();
  }
};

struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
  SchedBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<SchedBarrierOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<ROCDL::SchedBarrier>(op,
                                                     (uint32_t)op.getOpts());
    return success();
  }
};

} // namespace

/// Converts an MFMA vector operand from MLIR AMDGPU dialect convention to the
/// ROCDL and LLVM AMDGPU intrinsics convention.
///
/// Specifically:
/// 1. If the element type is bfloat16, bitcast it to i16 unless the ROCDL
/// intrinsic allows bf16. Newer MFMAs support bf16 operands; see the
/// IntrinsicsAMDGPU.td file for reference.
/// 2. If instead we have a more than 64-bit quantity, use a <N / 4 x i32>
/// instead, which is what the f8f6f4 intrinsics use.
/// 3. If `input` is a vector of N <= 8 bytes, bitcast it to a (N * 8)-bit
/// integer.
///
/// Note that the type of `input` has already been LLVM type converted:
/// therefore 8-bit and smaller floats are represented as their corresponding
/// `iN` integers.
static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
                                      Location loc, Value input,
                                      bool allowBf16 = true) {
  Type inputType = input.getType();
  if (auto vectorType = dyn_cast<VectorType>(inputType)) {
    if (vectorType.getElementType().isBF16() && !allowBf16)
      return LLVM::BitcastOp::create(
          rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
    if (vectorType.getElementType().isInteger(8) &&
        vectorType.getNumElements() <= 8)
      return LLVM::BitcastOp::create(
          rewriter, loc,
          rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
    if (isa<IntegerType>(vectorType.getElementType()) &&
        vectorType.getElementTypeBitWidth() <= 8) {
      int64_t numWords = llvm::divideCeil(
          vectorType.getNumElements() * vectorType.getElementTypeBitWidth(),
          32);
      return LLVM::BitcastOp::create(
          rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()),
          input);
    }
  }
  return input;
}
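
// For example (illustrative): a vector<4xi8> operand is bitcast to a single
// i32, while a vector<16xi8> (e.g. LLVM-type-converted packed fp8) becomes a
// vector<4xi32> per the f8f6f4 convention described above.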

/// Converts the scaled MFMA operands, `scalesA` and `scalesB`, from MLIR AMDGPU
/// dialect convention to ROCDL and LLVM AMDGPU intrinsics convention.
///
/// Specifically:
/// 1. If `input` is an i8 value, zero-extend it to i32.
/// 2. If `input` is a vector of length 4 and type i8, cast it to i32.
///
/// Note that the type of `input` has already been LLVM type converted:
/// therefore 8-bit and smaller floats are represented as their corresponding
/// `iN` integers.
static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter,
                                  Location loc, Value input) {
  Type inputType = input.getType();
  Type outputType = rewriter.getI32Type();
  if (auto intType = dyn_cast<IntegerType>(inputType))
    return LLVM::ZExtOp::create(rewriter, loc, outputType, input);
  return LLVM::BitcastOp::create(rewriter, loc, outputType, input);
}
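
// For example (illustrative): a lone i8 scale is zero-extended to an i32,
// while a vector<4xi8> of scales is bitcast to one i32 so that all four
// scales travel in a single register operand.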

/// Push an input operand. If it is a float type, nothing to do. If it is
/// an integer type, then we need to also push its signedness (1 for signed, 0
/// for unsigned) and we need to pack the input 16xi8 vector into a 4xi32
/// vector (or the 8xi8 vector into a 2xi32 one for gfx12+).
/// We also need to convert bfloat inputs to i16 to account for the bfloat
/// intrinsics having been defined before the AMD backend supported bfloat. We
/// similarly need to pack 8-bit float types into integers as if they were i8
/// (which they are for the backend's purposes).
static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
                                 Location loc,
                                 const TypeConverter *typeConverter,
                                 bool isUnsigned, Value llvmInput,
                                 Value mlirInput,
                                 SmallVector<Value, 4> &operands) {
  Type inputType = llvmInput.getType();
  auto vectorType = dyn_cast<VectorType>(inputType);
  if (!vectorType) {
    operands.push_back(llvmInput);
    return;
  }
  Type elemType = vectorType.getElementType();

  if (elemType.isBF16())
    llvmInput = LLVM::BitcastOp::create(
        rewriter, loc, vectorType.clone(rewriter.getI16Type()), llvmInput);
  if (elemType.getIntOrFloatBitWidth() > 8) {
    operands.push_back(llvmInput);
    return;
  }

  // We need to check the type of the input before conversion to properly test
  // for int8. This is because, in LLVM, fp8 type is converted to int8, so the
  // fp8/int8 information is lost during the conversion process.
  auto mlirInputType = cast<VectorType>(mlirInput.getType());
  bool isInputInteger = mlirInputType.getElementType().isInteger();
  if (isInputInteger) {
    // If the element type is explicitly 8-bit signed or unsigned, it overrides
    // the isUnsigned flag.
    bool localIsUnsigned = isUnsigned;
    if (elemType.isUnsignedInteger()) {
      localIsUnsigned = true;
    } else if (elemType.isSignedInteger()) {
      localIsUnsigned = false;
    }
    Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
    operands.push_back(sign);
  }

  int64_t numBits =
      vectorType.getNumElements() * elemType.getIntOrFloatBitWidth();
  Type i32 = rewriter.getI32Type();
  Type intrinsicInType = numBits <= 32
                             ? (Type)rewriter.getIntegerType(numBits)
                             : (Type)VectorType::get(numBits / 32, i32);
  auto llvmIntrinsicInType = typeConverter->convertType(intrinsicInType);
  Value castInput = rewriter.createOrFold<LLVM::BitcastOp>(
      loc, llvmIntrinsicInType, llvmInput);
  // The wave64-mode 16x16x16 intrinsics that take 4-bit integers only need
  // (256 / 64) * 4 = 16 bits of input (on gfx12+) but take i32 arguments.
  // Add in the zeros here.
  if (numBits < 32)
    castInput = LLVM::ZExtOp::create(rewriter, loc, i32, castInput);
  operands.push_back(castInput);
}
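
// For example (illustrative): a signed vector<16xi8> input pushes an i1 "1"
// sign operand followed by the payload packed as numBits = 128, i.e. a
// vector<4xi32>; a vector<4xi8> packs to exactly 32 bits and is passed as
// one i32 with no zero-extension needed.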

/// Push the output operand. For many cases this is only pushing the output in
/// the operand list. But when we have f16 -> f16 or bf16 -> bf16 intrinsics,
/// since the same number of VGPRs is used, we need to decide whether to store
/// the result in the upper or lower 16 bits of the VGPRs. To store the result
/// in the lower 16 bits, set subwordOffset to 1; otherwise the result will be
/// stored in the upper part. The subwordOffset must not be set for gfx12,
/// as the instructions have been changed to return fewer registers instead.
static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
                                  Location loc,
                                  const TypeConverter *typeConverter,
                                  Value output, int32_t subwordOffset,
                                  bool clamp, SmallVector<Value, 4> &operands) {
  Type inputType = output.getType();
  auto vectorType = dyn_cast<VectorType>(inputType);
  Type elemType = vectorType.getElementType();
  if (elemType.isBF16())
    output = LLVM::BitcastOp::create(
        rewriter, loc, vectorType.clone(rewriter.getI16Type()), output);
  operands.push_back(output);
  if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) {
    operands.push_back(createI1Constant(rewriter, loc, subwordOffset));
  } else if (elemType.isInteger(32)) {
    operands.push_back(createI1Constant(rewriter, loc, clamp));
  }
}

/// Return true if `type` is the E5M2 variant of an 8-bit float that is
/// supported by the `_bf8` instructions on the given `chipset`.
static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type) {
  return (chipset == kGfx942 && isa<Float8E5M2FNUZType>(type)) ||
         (hasOcpFp8(chipset) && isa<Float8E5M2Type>(type));
}

/// Return true if `type` is the E4M3FN variant of an 8-bit float that is
/// supported by the `_fp8` instructions on the given `chipset`.
static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type) {
  return (chipset == kGfx942 && isa<Float8E4M3FNUZType>(type)) ||
         (hasOcpFp8(chipset) && isa<Float8E4M3FNType>(type));
}

/// Return the `rocdl` intrinsic corresponding to an MFMA operation `mfma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
                                                  Chipset chipset) {
  uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
           b = mfma.getBlocks();
  Type sourceElem = getElementTypeOrSelf(mfma.getSourceA().getType());
  Type destElem = getElementTypeOrSelf(mfma.getDestC().getType());

  if (sourceElem.isF32() && destElem.isF32()) {
    if (mfma.getReducePrecision() && chipset >= kGfx942) {
      if (m == 32 && n == 32 && k == 4 && b == 1)
        return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
      if (m == 16 && n == 16 && k == 8 && b == 1)
        return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
    }
    if (m == 32 && n == 32 && k == 1 && b == 2)
      return ROCDL::mfma_f32_32x32x1f32::getOperationName();
    if (m == 16 && n == 16 && k == 1 && b == 4)
      return ROCDL::mfma_f32_16x16x1f32::getOperationName();
    if (m == 4 && n == 4 && k == 1 && b == 16)
      return ROCDL::mfma_f32_4x4x1f32::getOperationName();
    if (m == 32 && n == 32 && k == 2 && b == 1)
      return ROCDL::mfma_f32_32x32x2f32::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f32_16x16x4f32::getOperationName();
  }

  if (sourceElem.isF16() && destElem.isF32()) {
    if (chipset >= kGfx950) {
      if (m == 32 && n == 32 && k == 16 && b == 1)
        return ROCDL::mfma_f32_32x32x16_f16::getOperationName();
      if (m == 16 && n == 16 && k == 32 && b == 1)
        return ROCDL::mfma_f32_16x16x32_f16::getOperationName();
    }
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_f32_32x32x4f16::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_f32_16x16x4f16::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_f32_4x4x4f16::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_f32_32x32x8f16::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_f32_16x16x16f16::getOperationName();
  }

  if (sourceElem.isBF16() && destElem.isF32()) {
    if (chipset >= kGfx950) {
      if (m == 32 && n == 32 && k == 16 && b == 1)
        return ROCDL::mfma_f32_32x32x16_bf16::getOperationName();
      if (m == 16 && n == 16 && k == 32 && b == 1)
        return ROCDL::mfma_f32_16x16x32_bf16::getOperationName();
    }
    if (chipset >= kGfx90a) {
      if (m == 32 && n == 32 && k == 4 && b == 2)
        return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
      if (m == 16 && n == 16 && k == 4 && b == 4)
        return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
      if (m == 4 && n == 4 && k == 4 && b == 16)
        return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
      if (m == 32 && n == 32 && k == 8 && b == 1)
        return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
      if (m == 16 && n == 16 && k == 16 && b == 1)
        return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
    }
    if (m == 32 && n == 32 && k == 2 && b == 2)
      return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
    if (m == 16 && n == 16 && k == 2 && b == 4)
      return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
    if (m == 4 && n == 4 && k == 2 && b == 16)
      return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
    if (m == 32 && n == 32 && k == 4 && b == 1)
      return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
    if (m == 16 && n == 16 && k == 8 && b == 1)
      return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
  }

  if (sourceElem.isInteger(8) && destElem.isInteger(32)) {
    if (chipset >= kGfx950) {
      if (m == 32 && n == 32 && k == 32 && b == 1)
        return ROCDL::mfma_i32_32x32x32_i8::getOperationName();
      if (m == 16 && n == 16 && k == 64 && b == 1)
        return ROCDL::mfma_i32_16x16x64_i8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_i32_32x32x4i8::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_i32_16x16x4i8::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_i32_4x4x4i8::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_i32_32x32x8i8::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_i32_16x16x16i8::getOperationName();
    if (m == 32 && n == 32 && k == 16 && b == 1 && chipset >= kGfx942)
      return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
    if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >= kGfx942)
      return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
  }

  if (sourceElem.isF64() && destElem.isF64() && chipset >= kGfx90a) {
    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f64_16x16x4f64::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 4)
      return ROCDL::mfma_f64_4x4x4f64::getOperationName();
  }

  if (destElem.isF32() && typeIsExpectedBf8ForChipset(chipset, sourceElem)) {
    // Known to be correct because there are no scalar f8 instructions and
    // because a length mismatch will have been caught by the verifier.
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
    }
  }

  if (destElem.isF32() && typeIsExpectedFp8ForChipset(chipset, sourceElem)) {
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
    }
  }

  return std::nullopt;
}
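
// For example (illustrative): an `amdgpu.mfma` with m = n = 16, k = 16,
// blocks = 1, f16 sources, and an f32 destination selects
// ROCDL::mfma_f32_16x16x16f16 in the table above.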

static std::optional<uint32_t> mfmaTypeSelectCode(Type mlirElemType) {
  return llvm::TypeSwitch<Type, std::optional<uint32_t>>(mlirElemType)
      .Case([](Float8E4M3FNType) { return 0u; })
      .Case([](Float8E5M2Type) { return 1u; })
      .Case([](Float6E2M3FNType) { return 2u; })
      .Case([](Float6E3M2FNType) { return 3u; })
      .Case([](Float4E2M1FNType) { return 4u; })
      .Default([](Type) { return std::nullopt; });
}

/// If there is a scaled MFMA instruction for the input element types `aType`
/// and `bType`, output type `destType`, problem size M, N, K, and B (number of
/// blocks) on the given `chipset`, return a tuple consisting of the
/// OperationName of the intrinsic and the type codes that need to be passed to
/// that intrinsic. Note that this is also used to implement some un-scaled
/// MFMAs, since the compiler represents the ordinary instruction as a "scaled"
/// MFMA with a scale of 0.
static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m,
                        uint32_t n, uint32_t k, uint32_t b, Chipset chipset) {
  aType = getElementTypeOrSelf(aType);
  bType = getElementTypeOrSelf(bType);
  destType = getElementTypeOrSelf(destType);

  if (chipset < kGfx950)
    return std::nullopt;
  if (!isa<Float32Type>(destType))
    return std::nullopt;

  std::optional<uint32_t> aTypeCode = mfmaTypeSelectCode(aType);
  std::optional<uint32_t> bTypeCode = mfmaTypeSelectCode(bType);
  if (!aTypeCode || !bTypeCode)
    return std::nullopt;

  if (m == 32 && n == 32 && k == 64 && b == 1)
    return std::tuple{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(),
                      *aTypeCode, *bTypeCode};
  if (m == 16 && n == 16 && k == 128 && b == 1)
    return std::tuple{
        ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), *aTypeCode,
        *bTypeCode};

  return std::nullopt;
}
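
// For example (illustrative): f8E4M3FN sources with an f32 destination and
// m = n = 32, k = 64, b = 1 on gfx950 select
// ROCDL::mfma_scale_f32_32x32x64_f8f6f4 with both type codes equal to 0.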

static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
mfmaOpToScaledIntrinsic(MFMAOp mfma, Chipset chipset) {
  return mfmaOpToScaledIntrinsic(
      mfma.getSourceA().getType(), mfma.getSourceB().getType(),
      mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(),
      mfma.getBlocks(), chipset);
}

static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
  return mfmaOpToScaledIntrinsic(smfma.getSourceA().getType(),
                                 smfma.getSourceB().getType(),
                                 smfma.getDestC().getType(), smfma.getM(),
                                 smfma.getN(), smfma.getK(), 1u, chipset);
}

/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
                                                  Chipset chipset) {
  auto sourceVectorType = dyn_cast<VectorType>(wmma.getSourceA().getType());
  auto sourceBVectorType = dyn_cast<VectorType>(wmma.getSourceB().getType());
  auto destVectorType = dyn_cast<VectorType>(wmma.getDestC().getType());
  auto elemSourceType = sourceVectorType.getElementType();
  auto elemBSourceType = sourceBVectorType.getElementType();
  auto elemDestType = destVectorType.getElementType();

  if (elemSourceType.isF16() && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
  if (elemSourceType.isBF16() && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
  if (elemSourceType.isF16() && elemDestType.isF16())
    return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
  if (elemSourceType.isBF16() && elemDestType.isBF16())
    return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
  if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
    return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
  if (chipset.majorVersion == 11) {
    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
      return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
  }
  if (chipset.majorVersion >= 12) {
    if (isa<Float8E4M3FNType>(elemSourceType) &&
        isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
    if (isa<Float8E4M3FNType>(elemSourceType) &&
        isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
    if (isa<Float8E5M2Type>(elemSourceType) &&
        isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
    if (isa<Float8E5M2Type>(elemSourceType) &&
        isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) {
      bool isWave64 = destVectorType.getNumElements() == 4;
      // This is the ambiguous case. 8 inputs to the wave64 version means that
      // we want the 16x16x32 version, but for wave32 they mean the short form.
      bool has8Inputs = sourceVectorType.getNumElements() == 8;
      if ((isWave64 && has8Inputs) || (!isWave64 && !has8Inputs))
        return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
      return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
    }
  }
  return std::nullopt;
}
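
// For example (illustrative): f16 sources with an f32 destination select
// ROCDL::wmma_f32_16x16x16_f16; the gfx11/gfx12 applicability check is
// performed separately in WMMAOpLowering below.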

namespace {
struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
  MFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type outType = typeConverter->convertType(op.getDestD().getType());
    Type intrinsicOutType = outType;
    if (auto outVecType = dyn_cast<VectorType>(outType))
      if (outVecType.getElementType().isBF16())
        intrinsicOutType = outVecType.clone(rewriter.getI16Type());

    if (chipset.majorVersion != 9 || chipset < kGfx908)
      return op->emitOpError("MFMA only supported on gfx908+");
    uint32_t getBlgpField = static_cast<uint32_t>(op.getBlgp());
    if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
      if (chipset < kGfx942)
        return op.emitOpError("negation unsupported on older than gfx942");
      getBlgpField |=
          op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
    }
    std::optional<StringRef> maybeIntrinsic = mfmaOpToIntrinsic(op, chipset);
    std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
        maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
    if (!maybeIntrinsic.has_value() && !maybeScaledIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching MFMA size on given chipset");

    bool isScaled =
        !maybeIntrinsic.has_value() && maybeScaledIntrinsic.has_value();
    if (isScaled &&
        (adaptor.getAbid() > 0 || getBlgpField > 0 || op.getCbsz() > 0)) {
      return op.emitOpError(
          "non-default abid, blgp, and cbsz aren't supported on MFMAs that can "
          "be scaled as those fields are used for type information");
    }

    StringRef intrinsicName =
        isScaled ? std::get<0>(*maybeScaledIntrinsic) : *maybeIntrinsic;
    // Determine whether we can use bf16 in the intrinsic: newer MFMAs on
    // gfx950+ allow bf16 as the input; see the IntrinsicsAMDGPU.td file for
    // reference.
    bool allowBf16 = [&]() {
      if (chipset < kGfx950)
        return false;
      if (isScaled)
        return true;
      return intrinsicName.contains("16x16x32.bf16") ||
             intrinsicName.contains("32x32x16.bf16");
    }();
    OperationState loweredOp(loc, intrinsicName);
    loweredOp.addTypes(intrinsicOutType);
    loweredOp.addOperands({convertMFMAVectorOperand(
                               rewriter, loc, adaptor.getSourceA(), allowBf16),
                           convertMFMAVectorOperand(
                               rewriter, loc, adaptor.getSourceB(), allowBf16),
                           adaptor.getDestC()});
    if (isScaled) {
      Value zero = createI32Constant(rewriter, loc, 0);
      auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
      loweredOp.addOperands({createI32Constant(rewriter, loc, aTypeCode),
                             createI32Constant(rewriter, loc, bTypeCode),
                             /*scale A byte=*/zero, /*scale A=*/zero,
                             /*scale B byte=*/zero, /*scale B=*/zero});
    } else {
      loweredOp.addOperands({createI32Constant(rewriter, loc, op.getCbsz()),
                             createI32Constant(rewriter, loc, op.getAbid()),
                             createI32Constant(rewriter, loc, getBlgpField)});
    }
    Value lowered = rewriter.create(loweredOp)->getResult(0);
    if (outType != intrinsicOutType)
      lowered = LLVM::BitcastOp::create(rewriter, loc, outType, lowered);
    rewriter.replaceOp(op, lowered);
    return success();
  }
};

struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
  ScaledMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type intrinsicOutType = typeConverter->convertType(op.getDestD().getType());

    if (chipset.majorVersion != 9 || chipset < kGfx950)
      return op->emitOpError("scaled MFMA only supported on gfx950+");
    std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
        maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
    if (!maybeScaledIntrinsic.has_value())
      return op.emitOpError(
          "no intrinsic matching scaled MFMA size on given chipset");

    auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
    OperationState loweredOp(loc, intrinsicName);
    loweredOp.addTypes(intrinsicOutType);
    loweredOp.addOperands(
        {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
         convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
         adaptor.getDestC()});
    Value scalesIdxA =
        createI32Constant(rewriter, loc, adaptor.getScalesIdxA());
    Value scalesIdxB =
        createI32Constant(rewriter, loc, adaptor.getScalesIdxB());
    loweredOp.addOperands(
        {createI32Constant(rewriter, loc, aTypeCode),
         createI32Constant(rewriter, loc, bTypeCode),
         /*scales idx A=*/scalesIdxA,
         /*scales A=*/
         castMFMAScaleOperand(rewriter, loc, adaptor.getScalesA()),
         /*scales idx B=*/scalesIdxB,
         /*scales B=*/
         castMFMAScaleOperand(rewriter, loc, adaptor.getScalesB())});
    Value lowered = rewriter.create(loweredOp)->getResult(0);
    rewriter.replaceOp(op, lowered);
    return success();
  }
};

struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
  WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto outType =
        typeConverter->convertType<VectorType>(op.getDestD().getType());
    if (!outType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
      return op->emitOpError("WMMA only supported on gfx11 and gfx12");

    // The WMMA operations represent vectors of bf16s as vectors of i16s, so we
    // need to bitcast bfloats to i16 and then bitcast them back.
    VectorType rawOutType = outType;
    if (outType.getElementType().isBF16())
      rawOutType = outType.clone(rewriter.getI16Type());

    std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);

    if (!maybeIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching WMMA on the given chipset");

    if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0)
      return op.emitOpError("subwordOffset not supported on gfx12+");

    OperationState loweredOp(loc, *maybeIntrinsic);
    loweredOp.addTypes(rawOutType);

    SmallVector<Value, 4> operands;
    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(),
                         adaptor.getSourceA(), op.getSourceA(), operands);
    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(),
                         adaptor.getSourceB(), op.getSourceB(), operands);
    wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(),
                          op.getSubwordOffset(), op.getClamp(), operands);

    loweredOp.addOperands(operands);
    Operation *lowered = rewriter.create(loweredOp);

    Operation *maybeCastBack = lowered;
    if (rawOutType != outType)
      maybeCastBack = LLVM::BitcastOp::create(rewriter, loc, outType,
                                              lowered->getResult(0));
    rewriter.replaceOp(op, maybeCastBack->getResults());

    return success();
  }
};

struct TransposeLoadOpLowering
    : public ConvertOpToLLVMPattern<TransposeLoadOp> {
  TransposeLoadOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<TransposeLoadOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(TransposeLoadOp op, TransposeLoadOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset != kGfx950)
      return op.emitOpError("Non-gfx950 chipset not supported");

    Location loc = op.getLoc();
    auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());

    // Elements in sub-byte memrefs are stored non-contiguously, so reject
    // sub-byte source memrefs; use emulated memrefs instead.
    size_t srcElementSize =
        srcMemRefType.getElementType().getIntOrFloatBitWidth();
    if (srcElementSize < 8)
      return op.emitOpError("Expect source memref to have at least 8 bits "
                            "element size, got ")
             << srcElementSize;

    auto resultType = cast<VectorType>(op.getResult().getType());
    Value srcPtr =
        getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
                             (adaptor.getSrcIndices()));

    size_t numElements = resultType.getNumElements();
    size_t elementTypeSize =
        resultType.getElementType().getIntOrFloatBitWidth();

    // ROCDL transpose load intrinsics return vectors of 32-bit integers if
    // the element size is smaller than 16 bits.
    Type rocdlResultType = VectorType::get((numElements * elementTypeSize) / 32,
                                           rewriter.getIntegerType(32));
    Type llvmResultType = typeConverter->convertType(resultType);

    switch (elementTypeSize) {
    case 4: {
      assert(numElements == 16);
      auto rocdlOp = ROCDL::ds_read_tr4_b64::create(rewriter, loc,
                                                    rocdlResultType, srcPtr);
      rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
      break;
    }
    case 6: {
      assert(numElements == 16);
      auto rocdlOp = ROCDL::ds_read_tr6_b96::create(rewriter, loc,
                                                    rocdlResultType, srcPtr);
      rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
      break;
    }
    case 8: {
      assert(numElements == 8);
      auto rocdlOp = ROCDL::ds_read_tr8_b64::create(rewriter, loc,
                                                    rocdlResultType, srcPtr);
      rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
      break;
    }
    case 16: {
      assert(numElements == 4);
      rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr16_b64>(op, llvmResultType,
                                                           srcPtr);
      break;
    }
    default:
      return op.emitOpError("Unsupported element size for transpose load");
    }
    return success();
  }
};

struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
  GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset.majorVersion < 9 || chipset.majorVersion > 10)
      return op.emitOpError("pre-gfx9 and post-gfx10 not supported");

    Location loc = op.getLoc();

    auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
    auto dstMemRefType = cast<MemRefType>(op.getDst().getType());

    // TODO: instead of only transferring one element per thread, we could
    // augment it to transfer multiple elements per thread by issuing multiple
    // `global_load_lds` instructions.
    Type transferType = op.getTransferType();
    int loadWidth = [&]() -> int {
      if (auto transferVectorType = dyn_cast<VectorType>(transferType)) {
        return (transferVectorType.getNumElements() *
                transferVectorType.getElementTypeBitWidth()) /
               8;
      }
      return transferType.getIntOrFloatBitWidth() / 8;
    }();

    // Currently only 1, 2, 4, 12 and 16 byte loads are supported.
    if (!llvm::is_contained({1, 2, 4, 12, 16}, loadWidth))
      return op.emitOpError("chipset unsupported element size");

    if (chipset != kGfx950 && llvm::is_contained({12, 16}, loadWidth))
      return op.emitOpError("Gather to LDS instructions with 12-byte and "
                            "16-byte load widths are only supported on gfx950");

    Value srcPtr =
        getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
                             (adaptor.getSrcIndices()));
    Value dstPtr =
        getStridedElementPtr(rewriter, loc, dstMemRefType, adaptor.getDst(),
                             (adaptor.getDstIndices()));

    rewriter.replaceOpWithNewOp<ROCDL::LoadToLDSOp>(
        op, srcPtr, dstPtr, rewriter.getI32IntegerAttr(loadWidth),
        /*offset=*/rewriter.getI32IntegerAttr(0),
        /*aux=*/rewriter.getI32IntegerAttr(0), ArrayAttr{}, ArrayAttr{},
        ArrayAttr{});

    return success();
  }
};

namespace {
struct ExtPackedFp8OpLowering final
    : public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
  ExtPackedFp8OpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::ExtPackedFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

struct PackedTrunc2xFp8OpLowering final
    : public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> {
  PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter,
                             Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::PackedTrunc2xFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

struct PackedStochRoundFp8OpLowering final
    : public ConvertOpToLLVMPattern<PackedStochRoundFp8Op> {
  PackedStochRoundFp8OpLowering(const LLVMTypeConverter &converter,
                                Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::PackedStochRoundFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PackedStochRoundFp8Op op,
                  PackedStochRoundFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

struct ScaledExtPackedOpLowering final
    : public ConvertOpToLLVMPattern<ScaledExtPackedOp> {
  ScaledExtPackedOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedOp>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

struct PackedScaledTruncOpLowering final
    : public ConvertOpToLLVMPattern<PackedScaledTruncOp> {
  PackedScaledTruncOpLowering(const LLVMTypeConverter &converter,
                              Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::PackedScaledTruncOp>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

} // end namespace
1409 
1410 LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
1411  ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
1412  ConversionPatternRewriter &rewriter) const {
1413  Location loc = op.getLoc();
1414  if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
1415  return rewriter.notifyMatchFailure(
1416  loc, "Fp8 conversion instructions are not available on target "
1417  "architecture and their emulation is not implemented");
1418  Type v4i8 =
1419  getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type()));
1420  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
1421  Type f32 = getTypeConverter()->convertType(op.getResult().getType());
1422 
1423  Value source = adaptor.getSource();
1424  auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
1425  auto resultVecType = dyn_cast<VectorType>(op.getResult().getType());
1426  Type sourceElemType = getElementTypeOrSelf(op.getSource());
1427  // Extend to a v4i8
1428  if (!sourceVecType || sourceVecType.getNumElements() < 4) {
1429  Value longVec = LLVM::UndefOp::create(rewriter, loc, v4i8);
1430  if (!sourceVecType) {
1431  longVec = LLVM::InsertElementOp::create(
1432  rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0));
1433  } else {
1434  for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
1435  Value idx = createI32Constant(rewriter, loc, i);
1436  Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
1437  longVec =
1438  LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
1439  }
1440  }
1441  source = longVec;
1442  }
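 // At this point `source` always holds at least four bytes, so the bitcast
 // to i32 below is well-formed.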
1443  Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
1444  if (resultVecType) {
1445  if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
1446  rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source,
1447  op.getIndex());
1448  } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
1449  rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Fp8Op>(op, f32, i32Source,
1450  op.getIndex());
1451  }
1452  } else {
1453  if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
1454  rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
1455  op.getIndex());
1456  } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
1457  rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
1458  op.getIndex());
1459  }
1460  }
1461  return success();
1462 }
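// Illustrative sketch (op syntax assumed): extracting a single f32, e.g.
//   %f = amdgpu.ext_packed_fp8 %v[1] : vector<4xf8E4M3FN> to f32
// maps to one rocdl.cvt.f32.fp8 with word index 1; vector results use the
// packed rocdl.cvt.pk.f32.* forms instead.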
1463 
1464 LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
1465  ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
1466  ConversionPatternRewriter &rewriter) const {
1467  Location loc = op.getLoc();
1468  if (chipset != kGfx950)
1469  return rewriter.notifyMatchFailure(
1470  loc, "Scaled fp conversion instructions are not available on target "
1471  "architecture and their emulation is not implemented");
1472  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
1473 
1474  Value source = adaptor.getSource();
1475  Value scale = adaptor.getScale();
1476 
1477  VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
1478  Type sourceElemType = sourceVecType.getElementType();
1479  VectorType destVecType = cast<VectorType>(op.getResult().getType());
1480  Type destElemType = destVecType.getElementType();
1481 
1482  VectorType packedVecType;
1483  if (isa<Float8E5M2Type, Float8E4M3FNType>(sourceElemType)) {
1484  VectorType v4i8 = VectorType::get(4, rewriter.getI8Type());
1485  packedVecType = cast<VectorType>(getTypeConverter()->convertType(v4i8));
1486  } else if (isa<Float4E2M1FNType>(sourceElemType)) {
1487  VectorType v8i4 = VectorType::get(8, rewriter.getI4Type());
1488  packedVecType = cast<VectorType>(getTypeConverter()->convertType(v8i4));
1489  } else {
1490  llvm_unreachable("invalid element type for scaled ext");
1491  }
1492 
 1493  // Extend to a packedVectorType. Unlike the fp8 lowering above,
 1494  // `sourceVecType` is guaranteed non-null by the cast, so no scalar
 1495  // case is needed.
 1496  if (sourceVecType.getNumElements() < packedVecType.getNumElements()) {
 1497  Value longVec = LLVM::ZeroOp::create(rewriter, loc, packedVecType);
 1498  for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
 1499  Value idx = createI32Constant(rewriter, loc, i);
 1500  Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
 1501  longVec =
 1502  LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
 1503  }
 1504  source = longVec;
 1505  }
1509  Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
1510 
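 // Dispatch on the (source element type, destination element type) pair;
 // each supported pair maps to exactly one scaled-conversion intrinsic, and
 // anything else fails the match.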
1511  if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF32())
1512  rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Bf8Op>(
1513  op, destVecType, i32Source, scale, op.getIndex());
1514  else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF16())
1515  rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Bf8Op>(
1516  op, destVecType, i32Source, scale, op.getIndex());
1517  else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isBF16())
1518  rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Bf8Op>(
1519  op, destVecType, i32Source, scale, op.getIndex());
1520  else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF32())
1521  rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp8Op>(
1522  op, destVecType, i32Source, scale, op.getIndex());
1523  else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF16())
1524  rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp8Op>(
1525  op, destVecType, i32Source, scale, op.getIndex());
1526  else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isBF16())
1527  rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp8Op>(
1528  op, destVecType, i32Source, scale, op.getIndex());
1529  else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF32())
1530  rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp4Op>(
1531  op, destVecType, i32Source, scale, op.getIndex());
1532  else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF16())
1533  rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp4Op>(
1534  op, destVecType, i32Source, scale, op.getIndex());
1535  else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isBF16())
1536  rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp4Op>(
1537  op, destVecType, i32Source, scale, op.getIndex());
1538  else
1539  return failure();
1540 
1541  return success();
1542 }
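// Illustrative sketch (op and intrinsic spellings assumed):
//   %w = amdgpu.scaled_ext_packed %v[0], %scale
//          : vector<8xf4E2M1FN> to vector<2xf32>
// would select the Float4E2M1FN/f32 branch above, i.e.
// rocdl.cvt.scalef32.pk.f32.fp4, on gfx950.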
1543 
1544 LogicalResult PackedScaledTruncOpLowering::matchAndRewrite(
1545  PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
1546  ConversionPatternRewriter &rewriter) const {
1547  Location loc = op.getLoc();
1548  if (chipset != kGfx950)
1549  return rewriter.notifyMatchFailure(
1550  loc, "Scaled fp conversion instructions are not available on target "
1551  "architecture and their emulation is not implemented");
1552  Type v2i16 = getTypeConverter()->convertType(
1553  VectorType::get(2, rewriter.getI16Type()));
1554  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
1555 
1556  Type resultType = op.getResult().getType();
1557  Type resultElemType = getElementTypeOrSelf(resultType);
1558  VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
1559  Type sourceElemType = sourceVecType.getElementType();
1560 
1561  Type intResultType = isa<Float4E2M1FNType>(resultElemType) ? i32 : v2i16;
1562 
1563  Value source = adaptor.getSource();
1564  Value scale = adaptor.getScale();
1565  Value existing = adaptor.getExisting();
1566  if (existing)
1567  existing = LLVM::BitcastOp::create(rewriter, loc, intResultType, existing);
1568  else
1569  existing = LLVM::ZeroOp::create(rewriter, loc, intResultType);
1570 
1571  if (sourceVecType.getNumElements() < 2) {
1572  Value c0 = createI32Constant(rewriter, loc, 0);
1573  Value elem0 = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
1574  VectorType v2 = VectorType::get(2, sourceElemType);
1575  source = LLVM::ZeroOp::create(rewriter, loc, v2);
1576  source = LLVM::InsertElementOp::create(rewriter, loc, source, elem0, c0);
1577  }
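 // `source` now has at least two elements, as the packed intrinsics below
 // require.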
1578 
1579  Value sourceA, sourceB;
1580  if (sourceElemType.isF32()) {
1581  Value c0 = createI32Constant(rewriter, loc, 0);
1582  Value c1 = createI32Constant(rewriter, loc, 1);
1583  sourceA = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
1584  sourceB = LLVM::ExtractElementOp::create(rewriter, loc, source, c1);
1585  }
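 // The *F32 intrinsic variants take two scalar f32 operands rather than a
 // packed vector, hence the extractions above.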
1586 
1587  Value result;
1588  if (sourceElemType.isF32() && isa<Float8E5M2Type>(resultElemType))
1589  result = ROCDL::CvtScaleF32PkBf8F32Op::create(rewriter, loc, intResultType,
1590  existing, sourceA, sourceB,
1591  scale, op.getIndex());
1592  else if (sourceElemType.isF16() && isa<Float8E5M2Type>(resultElemType))
1593  result = ROCDL::CvtScaleF32PkBf8F16Op::create(
1594  rewriter, loc, intResultType, existing, source, scale, op.getIndex());
1595  else if (sourceElemType.isBF16() && isa<Float8E5M2Type>(resultElemType))
1596  result = ROCDL::CvtScaleF32PkBf8Bf16Op::create(
1597  rewriter, loc, intResultType, existing, source, scale, op.getIndex());
1598  else if (sourceElemType.isF32() && isa<Float8E4M3FNType>(resultElemType))
1599  result = ROCDL::CvtScaleF32PkFp8F32Op::create(rewriter, loc, intResultType,
1600  existing, sourceA, sourceB,
1601  scale, op.getIndex());
1602  else if (sourceElemType.isF16() && isa<Float8E4M3FNType>(resultElemType))
1603  result = ROCDL::CvtScaleF32PkFp8F16Op::create(
1604  rewriter, loc, intResultType, existing, source, scale, op.getIndex());
1605  else if (sourceElemType.isBF16() && isa<Float8E4M3FNType>(resultElemType))
1606  result = ROCDL::CvtScaleF32PkFp8Bf16Op::create(
1607  rewriter, loc, intResultType, existing, source, scale, op.getIndex());
1608  else if (sourceElemType.isF32() && isa<Float4E2M1FNType>(resultElemType))
1609  result = ROCDL::CvtScaleF32PkFp4F32Op::create(rewriter, loc, intResultType,
1610  existing, sourceA, sourceB,
1611  scale, op.getIndex());
1612  else if (sourceElemType.isF16() && isa<Float4E2M1FNType>(resultElemType))
1613  result = ROCDL::CvtScaleF32PkFp4F16Op::create(
1614  rewriter, loc, intResultType, existing, source, scale, op.getIndex());
1615  else if (sourceElemType.isBF16() && isa<Float4E2M1FNType>(resultElemType))
1616  result = ROCDL::CvtScaleF32PkFp4Bf16Op::create(
1617  rewriter, loc, intResultType, existing, source, scale, op.getIndex());
1618  else
1619  return failure();
1620 
1621  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
1622  op, getTypeConverter()->convertType(resultType), result);
1623  return success();
1624 }
1625 
1626 LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
1627  PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
1628  ConversionPatternRewriter &rewriter) const {
1629  Location loc = op.getLoc();
1630  if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
1631  return rewriter.notifyMatchFailure(
1632  loc, "Fp8 conversion instructions are not available on target "
1633  "architecture and their emulation is not implemented");
1634  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
1635 
1636  Type resultType = op.getResult().getType();
1637  Type resultElemType = getElementTypeOrSelf(resultType);
1638 
1639  Value sourceA = adaptor.getSourceA();
1640  Value sourceB = adaptor.getSourceB();
1641  if (!sourceB)
1642  sourceB = LLVM::UndefOp::create(rewriter, loc, sourceA.getType());
1643  Value existing = adaptor.getExisting();
1644  if (existing)
1645  existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
1646  else
1647  existing = LLVM::UndefOp::create(rewriter, loc, i32);
1648 
1649  Value result;
1650  if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
1651  result = ROCDL::CvtPkBf8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
1652  existing, op.getWordIndex());
1653  else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
1654  result = ROCDL::CvtPkFp8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
1655  existing, op.getWordIndex());
1656 
1657  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
1658  op, getTypeConverter()->convertType(resultType), result);
1659  return success();
1660 }
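// Illustrative sketch (op syntax assumed): truncating two f32 values into
// one word of a packed fp8 vector on gfx942, e.g.
//   %p = amdgpu.packed_trunc_2xfp8 %a, %b into undef[word 0]
//          : f32 to vector<4xf8E4M3FNUZ>
// becomes rocdl.cvt.pk.fp8.f32 plus a bitcast back to the vector type.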
1661 
1662 LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
1663  PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
1664  ConversionPatternRewriter &rewriter) const {
1665  Location loc = op.getLoc();
1666  if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
1667  return rewriter.notifyMatchFailure(
1668  loc, "Fp8 conversion instructions are not available on target "
1669  "architecture and their emulation is not implemented");
1670  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
1671 
1672  Type resultType = op.getResult().getType();
1673  Type resultElemType = getElementTypeOrSelf(resultType);
1674 
1675  Value source = adaptor.getSource();
1676  Value stoch = adaptor.getStochiasticParam();
1677  Value existing = adaptor.getExisting();
1678  if (existing)
1679  existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
1680  else
1681  existing = LLVM::UndefOp::create(rewriter, loc, i32);
1682 
1683  Value result;
1684  if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
1685  result = ROCDL::CvtSrBf8F32Op::create(rewriter, loc, i32, source, stoch,
1686  existing, op.getStoreIndex());
1687  else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
1688  result = ROCDL::CvtSrFp8F32Op::create(rewriter, loc, i32, source, stoch,
1689  existing, op.getStoreIndex());
1690 
1691  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
1692  op, getTypeConverter()->convertType(resultType), result);
1693  return success();
1694 }
1695 
 1696 // AMDGPUDPPLowering converts the amdgpu.dpp operation into the
 1697 // corresponding ROCDL DPPUpdateOp.
1698 struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
1699  AMDGPUDPPLowering(const LLVMTypeConverter &converter, Chipset chipset)
1700  : ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
1701  Chipset chipset;
1702 
1703  LogicalResult
1704  matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
1705  ConversionPatternRewriter &rewriter) const override {
1706 
1707  // Convert the source operand to the corresponding LLVM type
1708  Location loc = DppOp.getLoc();
1709  Value src = adaptor.getSrc();
1710  Value old = adaptor.getOld();
1711  Type srcType = src.getType();
1712  Type oldType = old.getType();
1713  Type llvmType = nullptr;
1714  if (srcType.getIntOrFloatBitWidth() < 32) {
1715  llvmType = rewriter.getI32Type();
1716  } else if (isa<FloatType>(srcType)) {
1717  llvmType = (srcType.getIntOrFloatBitWidth() == 32)
1718  ? rewriter.getF32Type()
1719  : rewriter.getF64Type();
1720  } else if (isa<IntegerType>(srcType)) {
1721  llvmType = (srcType.getIntOrFloatBitWidth() == 32)
1722  ? rewriter.getI32Type()
1723  : rewriter.getI64Type();
1724  }
1725  auto llvmSrcIntType = typeConverter->convertType(
1726  rewriter.getIntegerType(srcType.getIntOrFloatBitWidth()));
1727 
 1728  // If an operand is narrower than 32 bits, bitcast/widen it to a 32-bit value.
1729  auto convertOperand = [&](Value operand, Type operandType) {
1730  if (operandType.getIntOrFloatBitWidth() <= 16) {
1731  if (llvm::isa<FloatType>(operandType)) {
1732  operand =
1733  LLVM::BitcastOp::create(rewriter, loc, llvmSrcIntType, operand);
1734  }
1735  auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
1736  32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
1737  Value undefVec = LLVM::UndefOp::create(rewriter, loc, llvmVecType);
1738  operand =
1739  LLVM::InsertElementOp::create(rewriter, loc, undefVec, operand,
1740  createI32Constant(rewriter, loc, 0));
1741  operand = LLVM::BitcastOp::create(rewriter, loc, llvmType, operand);
1742  }
1743  return operand;
1744  };
1745 
1746  src = convertOperand(src, srcType);
1747  old = convertOperand(old, oldType);
1748 
 1749  // These encodings are taken from llvm/lib/Target/AMDGPU/SIDefines.h.
1750  enum DppCtrl : unsigned {
1751  ROW_SHL0 = 0x100,
1752  ROW_SHR0 = 0x110,
1753  ROW_ROR0 = 0x120,
1754  WAVE_SHL1 = 0x130,
1755  WAVE_ROL1 = 0x134,
1756  WAVE_SHR1 = 0x138,
1757  WAVE_ROR1 = 0x13C,
1758  ROW_MIRROR = 0x140,
1759  ROW_HALF_MIRROR = 0x141,
1760  BCAST15 = 0x142,
1761  BCAST31 = 0x143,
1762  };
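 // The row_shl/row_shr/row_ror cases below add the shift amount to the
 // base encoding, e.g. a row_shl by 3 encodes as 0x100 + 3 = 0x103.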
1763 
1764  auto kind = DppOp.getKind();
1765  auto permArgument = DppOp.getPermArgument();
1766  uint32_t DppCtrl = 0;
1767 
1768  switch (kind) {
1769 
1770  case DPPPerm::quad_perm:
1771  if (auto quadPermAttr = cast<ArrayAttr>(*permArgument)) {
1772  int32_t i = 0;
1773  for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
1774  uint32_t num = elem.getInt();
1775  DppCtrl |= num << (i * 2);
1776  i++;
1777  }
1778  }
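 // quad_perm packs four 2-bit lane selectors, so [0, 1, 2, 3] encodes
 // the identity permutation 0 | 1<<2 | 2<<4 | 3<<6 = 0xE4.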
1779  break;
1780  case DPPPerm::row_shl:
1781  if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
1782  DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
1783  }
1784  break;
1785  case DPPPerm::row_shr:
1786  if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
1787  DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
1788  }
1789  break;
1790  case DPPPerm::row_ror:
1791  if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
1792  DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
1793  }
1794  break;
1795  case DPPPerm::wave_shl:
1796  DppCtrl = DppCtrl::WAVE_SHL1;
1797  break;
1798  case DPPPerm::wave_shr:
1799  DppCtrl = DppCtrl::WAVE_SHR1;
1800  break;
1801  case DPPPerm::wave_rol:
1802  DppCtrl = DppCtrl::WAVE_ROL1;
1803  break;
1804  case DPPPerm::wave_ror:
1805  DppCtrl = DppCtrl::WAVE_ROR1;
1806  break;
1807  case DPPPerm::row_mirror:
1808  DppCtrl = DppCtrl::ROW_MIRROR;
1809  break;
1810  case DPPPerm::row_half_mirror:
1811  DppCtrl = DppCtrl::ROW_HALF_MIRROR;
1812  break;
1813  case DPPPerm::row_bcast_15:
1814  DppCtrl = DppCtrl::BCAST15;
1815  break;
1816  case DPPPerm::row_bcast_31:
1817  DppCtrl = DppCtrl::BCAST31;
1818  break;
1819  }
1820 
 1821  // Extract the row_mask, bank_mask, and bound_ctrl attribute values.
1823  auto rowMask = DppOp->getAttrOfType<IntegerAttr>("row_mask").getInt();
1824  auto bankMask = DppOp->getAttrOfType<IntegerAttr>("bank_mask").getInt();
1825  bool boundCtrl = DppOp->getAttrOfType<BoolAttr>("bound_ctrl").getValue();
1826 
 1827  // Create a ROCDL DPPUpdateOp with the computed control word and masks.
1828  auto dppMovOp =
1829  ROCDL::DPPUpdateOp::create(rewriter, loc, llvmType, old, src, DppCtrl,
1830  rowMask, bankMask, boundCtrl);
1831 
1832  Value result = dppMovOp.getRes();
1833  if (srcType.getIntOrFloatBitWidth() < 32) {
1834  result = LLVM::TruncOp::create(rewriter, loc, llvmSrcIntType, result);
1835  if (!llvm::isa<IntegerType>(srcType)) {
1836  result = LLVM::BitcastOp::create(rewriter, loc, srcType, result);
1837  }
1838  }
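 // Sub-32-bit results were computed in a widened register, so truncate
 // back and, for float types, bitcast from the intermediate integer type.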
1839 
 1840  // Replace the amdgpu.dpp operation with the result of the new
 1841  // ROCDL DPPUpdateOp.
1842  rewriter.replaceOp(DppOp, ValueRange(result));
1843  return success();
1844  }
1845 };
1846 
1847 struct AMDGPUSwizzleBitModeLowering
1848  : public ConvertOpToLLVMPattern<SwizzleBitModeOp> {
 1849  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
 1850 
1851  LogicalResult
1852  matchAndRewrite(SwizzleBitModeOp op, OpAdaptor adaptor,
1853  ConversionPatternRewriter &rewriter) const override {
1854  Location loc = op.getLoc();
1855  Type i32 = rewriter.getI32Type();
1856  Value src = adaptor.getSrc();
1857  SmallVector<Value> decomposed =
1858  LLVM::decomposeValue(rewriter, loc, src, i32);
1859  unsigned andMask = op.getAndMask();
1860  unsigned orMask = op.getOrMask();
1861  unsigned xorMask = op.getXorMask();
1862 
1863  // bit 15 is 0 for the BitMode swizzle.
1864  // https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/
1865  unsigned mask = andMask | (orMask << 5) | (xorMask << 10);
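 // For example, andMask = 0x1f, orMask = 0, xorMask = 1 gives
 // mask = 0x1f | (0 << 5) | (1 << 10) = 0x41f, which (per the swizzle
 // semantics linked above) swaps adjacent lanes.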
1866  Value maskValue = createI32Constant(rewriter, loc, mask);
1867  SmallVector<Value> swizzled;
1868  for (Value v : decomposed) {
1869  Value res =
1870  ROCDL::DsSwizzleOp::create(rewriter, loc, v.getType(), v, maskValue);
1871  swizzled.emplace_back(res);
1872  }
1873 
1874  Value result = LLVM::composeValue(rewriter, loc, swizzled, src.getType());
1875  rewriter.replaceOp(op, result);
1876  return success();
1877  }
1878 };
1879 
1880 struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneSwapOp> {
 1881  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
 1882 
1883  AMDGPUPermlaneLowering(const LLVMTypeConverter &converter, Chipset chipset)
1884  : ConvertOpToLLVMPattern<PermlaneSwapOp>(converter), chipset(chipset) {}
1885  Chipset chipset;
1886 
1887  LogicalResult
1888  matchAndRewrite(PermlaneSwapOp op, OpAdaptor adaptor,
1889  ConversionPatternRewriter &rewriter) const override {
1890  if (chipset < kGfx950)
1891  return op->emitOpError("permlane_swap is only supported on gfx950+");
1892 
1893  Location loc = op.getLoc();
1894  Type i32 = rewriter.getI32Type();
1895  Value src = adaptor.getSrc();
1896  unsigned rowLength = op.getRowLength();
1897  bool fi = op.getFetchInactive();
1898  bool boundctrl = op.getBoundCtrl();
1899 
1900  SmallVector<Value> decomposed =
1901  LLVM::decomposeValue(rewriter, loc, src, i32);
1902 
1903  SmallVector<Value> permuted;
1904  for (Value v : decomposed) {
1905  Value res;
1906  Type i32pair = LLVM::LLVMStructType::getLiteral(
1907  rewriter.getContext(), {v.getType(), v.getType()});
1908 
1909  if (rowLength == 16)
1910  res = ROCDL::Permlane16SwapOp::create(rewriter, loc, i32pair, v, v, fi,
1911  boundctrl);
1912  else if (rowLength == 32)
1913  res = ROCDL::Permlane32SwapOp::create(rewriter, loc, i32pair, v, v, fi,
1914  boundctrl);
1915  else
1916  llvm_unreachable("unsupported row length");
1917 
1918  Value vdstNew = LLVM::ExtractValueOp::create(rewriter, loc, res, {0});
1919  permuted.emplace_back(vdstNew);
1920  }
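 // The permlane swap intrinsics return a struct of two i32 values; only
 // the first (the updated value for this word) is used here.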
1921 
1922  Value result = LLVM::composeValue(rewriter, loc, permuted, src.getType());
1923  rewriter.replaceOp(op, result);
1924  return success();
1925  }
1926 };
1927 
1928 struct ConvertAMDGPUToROCDLPass
1929  : public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
1930  using Base::Base;
1931 
1932  void runOnOperation() override {
1933  MLIRContext *ctx = &getContext();
1934  FailureOr<Chipset> maybeChipset = Chipset::parse(chipset);
1935  if (failed(maybeChipset)) {
1936  emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
1937  return signalPassFailure();
1938  }
1939 
 1940  RewritePatternSet patterns(ctx);
 1941  LLVMTypeConverter converter(ctx);
1942  populateAMDGPUToROCDLConversionPatterns(converter, patterns, *maybeChipset);
1943  LLVMConversionTarget target(getContext());
1944  target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
1945  target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
1946  target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
1947  if (failed(applyPartialConversion(getOperation(), target,
1948  std::move(patterns))))
1949  signalPassFailure();
1950  }
1951 };
1952 } // namespace
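// Usage sketch, assuming the standard mlir-opt registration of this pass:
//   mlir-opt --convert-amdgpu-to-rocdl="chipset=gfx942" input.mlir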
1953 
 1954 void mlir::populateAMDGPUMemorySpaceAttributeConversions(
 1955  TypeConverter &typeConverter) {
1956  typeConverter.addTypeAttributeConversion(
 1957  [](BaseMemRefType type, amdgpu::AddressSpaceAttr as)
 1958  -> TypeConverter::AttributeConversionResult {
1959  MLIRContext *ctx = as.getContext();
1960  Type i64 = IntegerType::get(ctx, 64);
1961  switch (as.getValue()) {
1962  case amdgpu::AddressSpace::FatRawBuffer:
1963  return IntegerAttr::get(i64, 7);
1964  case amdgpu::AddressSpace::BufferRsrc:
1965  return IntegerAttr::get(i64, 8);
1966  case amdgpu::AddressSpace::FatStructuredBuffer:
1967  return IntegerAttr::get(i64, 9);
1968  }
 1969  return TypeConverter::AttributeConversionResult::abort();
 1970  });
1971 }
1972 
 1973 void mlir::populateAMDGPUToROCDLConversionPatterns(
 1974  LLVMTypeConverter &converter, RewritePatternSet &patterns,
 1975  Chipset chipset) {
 1976  populateAMDGPUMemorySpaceAttributeConversions(converter);
1977  patterns
1978  .add<FatRawBufferCastLowering,
1979  RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
1980  RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
1981  RawBufferOpLowering<RawBufferAtomicFaddOp,
1982  ROCDL::RawPtrBufferAtomicFaddOp>,
1983  RawBufferOpLowering<RawBufferAtomicFmaxOp,
1984  ROCDL::RawPtrBufferAtomicFmaxOp>,
1985  RawBufferOpLowering<RawBufferAtomicSmaxOp,
1986  ROCDL::RawPtrBufferAtomicSmaxOp>,
1987  RawBufferOpLowering<RawBufferAtomicUminOp,
1988  ROCDL::RawPtrBufferAtomicUminOp>,
1989  RawBufferOpLowering<RawBufferAtomicCmpswapOp,
1990  ROCDL::RawPtrBufferAtomicCmpSwap>,
1991  AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
1992  SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
1993  WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
1994  PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
1995  PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
1996  TransposeLoadOpLowering, AMDGPUPermlaneLowering>(converter, chipset);
1997  patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
1998 }
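// Illustrative use from a custom pipeline (a sketch under the assumption
// that the caller owns the converter and pattern set):
//   RewritePatternSet patterns(ctx);
//   LLVMTypeConverter converter(ctx);
//   populateAMDGPUToROCDLConversionPatterns(converter, patterns,
//                                           Chipset(9, 4, 2));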