//===- AMDGPUToROCDL.cpp - AMDGPU to ROCDL dialect conversion -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"

#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"

#include "../LLVMCommon/MemRefDescriptor.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_CONVERTAMDGPUTOROCDLPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::amdgpu;

// Define commonly used chipset versions for convenience.
constexpr Chipset kGfx908 = Chipset(9, 0, 8);
constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
constexpr Chipset kGfx942 = Chipset(9, 4, 2);
constexpr Chipset kGfx950 = Chipset(9, 5, 0);

/// Convert an unsigned number `val` to i32.
static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
                                  Location loc, Value val) {
  IntegerType i32 = rewriter.getI32Type();
  // Force check that `val` is of int type.
  auto valTy = cast<IntegerType>(val.getType());
  if (i32 == valTy)
    return val;
  return valTy.getWidth() > 32
             ? Value(rewriter.create<LLVM::TruncOp>(loc, i32, val))
             : Value(rewriter.create<LLVM::ZExtOp>(loc, i32, val));
}

static Value createI32Constant(ConversionPatternRewriter &rewriter,
                               Location loc, int32_t value) {
  Type i32 = rewriter.getI32Type();
  return rewriter.create<LLVM::ConstantOp>(loc, i32, value);
}

static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
                              bool value) {
  Type llvmI1 = rewriter.getI1Type();
  return rewriter.create<LLVM::ConstantOp>(loc, llvmI1, value);
}

/// Returns the linear index used to access an element in the memref.
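/// For example (illustrative), indices `(i, j)` into a memref with strides
/// `[s0, 1]` lower to the i32 value `i * s0 + j`; dynamic strides are read
/// from the memref descriptor rather than materialized as constants.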
static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
                               Location loc, MemRefDescriptor &memRefDescriptor,
                               ValueRange indices, ArrayRef<int64_t> strides) {
  IntegerType i32 = rewriter.getI32Type();
  Value index;
  for (auto [i, increment, stride] : llvm::enumerate(indices, strides)) {
    if (stride != 1) { // Skip if stride is 1.
      Value strideValue =
          ShapedType::isDynamic(stride)
              ? convertUnsignedToI32(rewriter, loc,
                                     memRefDescriptor.stride(rewriter, loc, i))
              : rewriter.create<LLVM::ConstantOp>(loc, i32, stride);
      increment = rewriter.create<LLVM::MulOp>(loc, increment, strideValue);
    }
    index =
        index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
  }
  return index ? index : createI32Constant(rewriter, loc, 0);
}

/// Compute the contents of the `num_records` field for a given memref
/// descriptor - that is, the number of bytes that's one element past the
/// greatest possible valid index into the memref.
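/// For example (illustrative), a memref<16x4xf32> with strides [4, 1] gives
/// max(16 * 4, 4 * 1) * 4 bytes = 256.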
static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc,
                           MemRefType memrefType,
                           MemRefDescriptor &memrefDescriptor,
                           ArrayRef<int64_t> strides,
                           uint32_t elementByteWidth) {
  if (memrefType.hasStaticShape() &&
      !llvm::any_of(strides, ShapedType::isDynamic)) {
    int64_t size = memrefType.getRank() == 0 ? 1 : 0;
    ArrayRef<int64_t> shape = memrefType.getShape();
    for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
      size = std::max(shape[i] * strides[i], size);
    size = size * elementByteWidth;
    assert(size < std::numeric_limits<uint32_t>::max() &&
           "the memref buffer is too large");
    return createI32Constant(rewriter, loc, static_cast<int32_t>(size));
  }
  Value maxIndex;
  for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
    Value size = memrefDescriptor.size(rewriter, loc, i);
    Value stride = memrefDescriptor.stride(rewriter, loc, i);
    Value maxThisDim = rewriter.create<LLVM::MulOp>(loc, size, stride);
    maxIndex = maxIndex
                   ? rewriter.create<LLVM::UMaxOp>(loc, maxIndex, maxThisDim)
                   : maxThisDim;
  }
  Value maxIndexI32 = convertUnsignedToI32(rewriter, loc, maxIndex);
  Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);
  return rewriter.create<LLVM::MulOp>(loc, maxIndexI32, byteWidthConst);
}

static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc,
                            Value basePointer, Value numRecords,
                            bool boundsCheck, amdgpu::Chipset chipset,
                            Value cacheSwizzleStride = nullptr,
                            unsigned addressSpace = 8) {
  // The stride value is generally 0. However, on MI-300 and onward, you can
  // enable a cache swizzling mode by setting bit 14 of the stride field
  // and setting that stride to a cache stride.
  Type i16 = rewriter.getI16Type();
  Value stride;
  if (chipset.majorVersion == 9 && chipset >= kGfx942 && cacheSwizzleStride) {
    Value cacheStrideZext =
        rewriter.create<LLVM::ZExtOp>(loc, i16, cacheSwizzleStride);
    Value swizzleBit = rewriter.create<LLVM::ConstantOp>(
        loc, i16, rewriter.getI16IntegerAttr(1 << 14));
    stride = rewriter.create<LLVM::OrOp>(loc, cacheStrideZext, swizzleBit,
                                         /*isDisjoint=*/true);
  } else {
    stride = rewriter.create<LLVM::ConstantOp>(loc, i16,
                                               rewriter.getI16IntegerAttr(0));
  }
  // Flag word:
  // bits 0-11: dst sel, ignored by these intrinsics
  // bits 12-14: num format (ignored, must be nonzero, 7=float)
  // bits 15-18: data format (ignored, must be nonzero, 4=32bit)
  // bit 19: In nested heap (0 here)
  // bit 20: Behavior on unmap (0 means "return 0 / ignore")
  // bits 21-22: Index stride for swizzles (N/A)
  // bit 23: Add thread ID (0)
  // bit 24: Reserved to 1 (RDNA) or 0 (CDNA)
  // bits 25-26: Reserved (0)
  // bit 27: Buffer is non-volatile (CDNA only)
  // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 =
  // none, 3 = either swizzles or testing against offset field) RDNA only
  // bits 30-31: Type (must be 0)
  uint32_t flags = (7 << 12) | (4 << 15);
  if (chipset.majorVersion >= 10) {
    flags |= (1 << 24);
    uint32_t oob = boundsCheck ? 3 : 2;
    flags |= (oob << 28);
  }
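  // For instance (illustrative), gfx10+ with boundsCheck=true produces
  // flags == (7 << 12) | (4 << 15) | (1 << 24) | (3u << 28).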
  Value flagsConst = createI32Constant(rewriter, loc, flags);
  Type rsrcType =
      LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
  Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
      loc, rsrcType, basePointer, stride, numRecords, flagsConst);
  return resource;
}

namespace {
struct FatRawBufferCastLowering
    : public ConvertOpToLLVMPattern<FatRawBufferCastOp> {
  FatRawBufferCastLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<FatRawBufferCastOp>(converter),
        chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value memRef = adaptor.getSource();
    Value unconvertedMemref = op.getSource();
    MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());
    MemRefDescriptor descriptor(memRef);

    DataLayout dataLayout = DataLayout::closest(op);
    int64_t elementByteWidth =
        dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;

    int64_t unusedOffset = 0;
    SmallVector<int64_t, 5> strideVals;
    if (failed(memrefType.getStridesAndOffset(strideVals, unusedOffset)))
      return op.emitOpError("Can't lower non-stride-offset memrefs");

    Value numRecords = adaptor.getValidBytes();
    if (!numRecords)
      numRecords = getNumRecords(rewriter, loc, memrefType, descriptor,
                                 strideVals, elementByteWidth);

    Value basePointer =
        adaptor.getResetOffset()
            ? descriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
                                   memrefType)
            : descriptor.alignedPtr(rewriter, loc);

    Value offset = adaptor.getResetOffset()
                       ? rewriter.create<LLVM::ConstantOp>(
                             loc, getIndexType(), rewriter.getIndexAttr(0))
                       : descriptor.offset(rewriter, loc);

    bool hasSizes = memrefType.getRank() > 0;
    // No need to unpack() and pack() all the individual sizes and strides,
    // so we'll just extract the arrays.
    Value sizes = hasSizes ? rewriter.create<LLVM::ExtractValueOp>(
                                 loc, descriptor, kSizePosInMemRefDescriptor)
                           : Value{};
    Value strides = hasSizes
                        ? rewriter.create<LLVM::ExtractValueOp>(
                              loc, descriptor, kStridePosInMemRefDescriptor)
                        : Value{};

    Value fatPtr = makeBufferRsrc(
        rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(),
        chipset, adaptor.getCacheSwizzleStride(), /*addressSpace=*/7);

    Value result = MemRefDescriptor::poison(
        rewriter, loc,
        getTypeConverter()->convertType(op.getResult().getType()));
    result = rewriter.create<LLVM::InsertValueOp>(
        loc, result, fatPtr, kAllocatedPtrPosInMemRefDescriptor);
    result = rewriter.create<LLVM::InsertValueOp>(
        loc, result, fatPtr, kAlignedPtrPosInMemRefDescriptor);
    result = rewriter.create<LLVM::InsertValueOp>(loc, result, offset,
                                                  kOffsetPosInMemRefDescriptor);
    if (hasSizes) {
      result = rewriter.create<LLVM::InsertValueOp>(loc, result, sizes,
                                                    kSizePosInMemRefDescriptor);
      result = rewriter.create<LLVM::InsertValueOp>(
          loc, result, strides, kStridePosInMemRefDescriptor);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// Define lowering patterns for raw buffer ops
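/// For example (illustrative instantiation), raw buffer loads would be handled
/// by `RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>`.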
template <typename GpuOp, typename Intrinsic>
struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
  RawBufferOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}

  Chipset chipset;
  static constexpr uint32_t maxVectorOpWidth = 128;

  LogicalResult
  matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = gpuOp.getLoc();
    Value memref = adaptor.getMemref();
    Value unconvertedMemref = gpuOp.getMemref();
    MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());

    if (chipset.majorVersion < 9)
      return gpuOp.emitOpError("raw buffer ops require GCN or higher");

    Value storeData = adaptor.getODSOperands(0)[0];
    if (storeData == memref) // no write component to this op
      storeData = Value();
    Type wantedDataType;
    if (storeData)
      wantedDataType = storeData.getType();
    else
      wantedDataType = gpuOp.getODSResults(0)[0].getType();

    Value atomicCmpData = Value();
    // Operand index 1 of a load is the indices; trying to read them can crash.
    if (storeData) {
      Value maybeCmpData = adaptor.getODSOperands(1)[0];
      if (maybeCmpData != memref)
        atomicCmpData = maybeCmpData;
    }

    Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);

    Type i32 = rewriter.getI32Type();

    // Get the type size in bytes.
    DataLayout dataLayout = DataLayout::closest(gpuOp);
    int64_t elementByteWidth =
        dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;
    Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);

    // If we want to load a vector<NxT> with total size <= 32 bits, use a
    // scalar load and bitcast it. Similarly, if bitsize(T) < 32 and the total
    // load size is >= 32, use a vector load of (N * bitsize(T)) / 32 x i32 and
    // bitcast. Also, the CAS intrinsic requires integer operands, so bitcast
    // any floats to integers.
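    // For instance (illustrative), a vector<4xf16> value (64 bits) is
    // reinterpreted as vector<2xi32>, and a vector<2xi8> value becomes i16.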
    Type llvmBufferValType = llvmWantedDataType;
    if (atomicCmpData) {
      if (auto floatType = dyn_cast<FloatType>(wantedDataType))
        llvmBufferValType = this->getTypeConverter()->convertType(
            rewriter.getIntegerType(floatType.getWidth()));
    }
    if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
      uint32_t vecLen = dataVector.getNumElements();
      uint32_t elemBits =
          dataLayout.getTypeSizeInBits(dataVector.getElementType());
      uint32_t totalBits = elemBits * vecLen;
      bool usePackedFp16 =
          isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
      if (totalBits > maxVectorOpWidth)
        return gpuOp.emitOpError(
            "Total width of loads or stores must be no more than " +
            Twine(maxVectorOpWidth) + " bits, but we call for " +
            Twine(totalBits) +
            " bits. This should've been caught in validation");
      if (!usePackedFp16 && elemBits < 32) {
        if (totalBits > 32) {
          if (totalBits % 32 != 0)
            return gpuOp.emitOpError("Load or store of more than 32-bits that "
                                     "doesn't fit into words. Can't happen\n");
          llvmBufferValType = this->typeConverter->convertType(
              VectorType::get(totalBits / 32, i32));
        } else {
          llvmBufferValType = this->typeConverter->convertType(
              rewriter.getIntegerType(totalBits));
        }
      }
    }
    if (auto vecType = dyn_cast<VectorType>(llvmBufferValType)) {
      // Buffer intrinsics don't support 1-element vectors; cast them to
      // scalars.
      if (vecType.getNumElements() == 1)
        llvmBufferValType = vecType.getElementType();
    }

    SmallVector<Value, 6> args;
    if (storeData) {
      if (llvmBufferValType != llvmWantedDataType) {
        Value castForStore =
            rewriter.create<LLVM::BitcastOp>(loc, llvmBufferValType, storeData);
        args.push_back(castForStore);
      } else {
        args.push_back(storeData);
      }
    }

    if (atomicCmpData) {
      if (llvmBufferValType != llvmWantedDataType) {
        Value castForCmp = rewriter.create<LLVM::BitcastOp>(
            loc, llvmBufferValType, atomicCmpData);
        args.push_back(castForCmp);
      } else {
        args.push_back(atomicCmpData);
      }
    }

    // Construct the buffer descriptor from the memref and attributes.
    int64_t offset = 0;
    SmallVector<int64_t, 5> strides;
    if (failed(memrefType.getStridesAndOffset(strides, offset)))
      return gpuOp.emitOpError("Can't lower non-stride-offset memrefs");

    MemRefDescriptor memrefDescriptor(memref);

    Value ptr = memrefDescriptor.bufferPtr(
        rewriter, loc, *this->getTypeConverter(), memrefType);
    Value numRecords = getNumRecords(
        rewriter, loc, memrefType, memrefDescriptor, strides, elementByteWidth);
    Value resource = makeBufferRsrc(rewriter, loc, ptr, numRecords,
                                    adaptor.getBoundsCheck(), chipset);
    args.push_back(resource);

    // Indexing (voffset)
    Value voffset = getLinearIndexI32(rewriter, loc, memrefDescriptor,
                                      adaptor.getIndices(), strides);
    if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
        indexOffset && *indexOffset > 0) {
      Value extraOffsetConst = createI32Constant(rewriter, loc, *indexOffset);
      voffset =
          voffset ? rewriter.create<LLVM::AddOp>(loc, voffset, extraOffsetConst)
                  : extraOffsetConst;
    }
    voffset = rewriter.create<LLVM::MulOp>(loc, voffset, byteWidthConst);
    args.push_back(voffset);

    // SGPR offset.
    Value sgprOffset = adaptor.getSgprOffset();
    if (!sgprOffset)
      sgprOffset = createI32Constant(rewriter, loc, 0);
    sgprOffset = rewriter.create<LLVM::MulOp>(loc, sgprOffset, byteWidthConst);
    args.push_back(sgprOffset);

    // bit 0: GLC = 0 (atomics drop value, less coherency)
    // bits 1-2: SLC, DLC = 0 (similarly)
    // bit 3: swizzled (0 for raw)
    args.push_back(createI32Constant(rewriter, loc, 0));

    llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
                                           llvmBufferValType);
    Operation *lowered = rewriter.create<Intrinsic>(loc, resultTypes, args,
                                                    ArrayRef<NamedAttribute>());
    if (lowered->getNumResults() == 1) {
      Value replacement = lowered->getResult(0);
      if (llvmBufferValType != llvmWantedDataType) {
        replacement = rewriter.create<LLVM::BitcastOp>(loc, llvmWantedDataType,
                                                       replacement);
      }
      rewriter.replaceOp(gpuOp, replacement);
    } else {
      rewriter.eraseOp(gpuOp);
    }
    return success();
  }
};

struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
  LDSBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    bool requiresInlineAsm = chipset < kGfx90a || chipset.majorVersion == 11;

    if (requiresInlineAsm) {
      auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
                                                      LLVM::AsmDialect::AD_ATT);
      const char *asmStr =
          ";;;WARNING: BREAKS DEBUG WATCHES\ns_waitcnt lgkmcnt(0)\ns_barrier";
      const char *constraints = "";
      rewriter.replaceOpWithNewOp<LLVM::InlineAsmOp>(
          op,
          /*resultTypes=*/TypeRange(), /*operands=*/ValueRange(),
          /*asm_string=*/asmStr, constraints, /*has_side_effects=*/true,
          /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
          /*operand_attrs=*/ArrayAttr());
      return success();
    }
    if (chipset.majorVersion < 12) {
      constexpr int32_t ldsOnlyBitsGfx6789 = ~(0x1f << 8);
      constexpr int32_t ldsOnlyBitsGfx10 = ~(0x3f << 8);
      // Left in place in case someone disables the inline ASM path or future
      // chipsets use the same bit pattern.
      constexpr int32_t ldsOnlyBitsGfx11 = ~(0x3f << 4);

      int32_t ldsOnlyBits;
      if (chipset.majorVersion == 11)
        ldsOnlyBits = ldsOnlyBitsGfx11;
      else if (chipset.majorVersion == 10)
        ldsOnlyBits = ldsOnlyBitsGfx10;
      else if (chipset.majorVersion <= 9)
        ldsOnlyBits = ldsOnlyBitsGfx6789;
      else
        return op.emitOpError(
                   "don't know how to lower this for chipset major version")
               << chipset.majorVersion;

      Location loc = op->getLoc();
      rewriter.create<ROCDL::SWaitcntOp>(loc, ldsOnlyBits);
      rewriter.replaceOpWithNewOp<ROCDL::SBarrierOp>(op);
    } else {
      Location loc = op->getLoc();
      rewriter.create<ROCDL::WaitDscntOp>(loc, 0);
      rewriter.create<ROCDL::BarrierSignalOp>(loc, -1);
      rewriter.replaceOpWithNewOp<ROCDL::BarrierWaitOp>(op, -1);
    }

    return success();
  }
};

struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
  SchedBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<SchedBarrierOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<ROCDL::SchedBarrier>(op,
                                                     (uint32_t)op.getOpts());
    return success();
  }
};

} // namespace

/// Converts a MFMA vector operand from MLIR AMDGPU dialect convention to ROCDL
/// and LLVM AMDGPU intrinsics convention.
///
/// Specifically:
/// 1. If the element type is bfloat16, bitcast it to i16.
/// 2. If instead we have a more than 64-bit quantity, use a <N / 4 x i32>
///    instead, which is what the f8f6f4 intrinsics use.
/// 3. If `input` is a vector of N <= 8 bytes, bitcast it to a (N * 8)-bit
///    integer.
///
/// Note that the type of `input` has already been LLVM type converted:
/// therefore 8-bit and smaller floats are represented as their corresponding
/// `iN` integers.
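/// For example (illustrative), a vector<4xi8> operand (e.g. four 8-bit floats
/// after LLVM type conversion) is bitcast to i32, while a vector<8xbf16>
/// operand becomes vector<8xi16>.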
static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
                                      Location loc, Value input) {
  Type inputType = input.getType();
  if (auto vectorType = dyn_cast<VectorType>(inputType)) {
    if (vectorType.getElementType().isBF16())
      return rewriter.create<LLVM::BitcastOp>(
          loc, vectorType.clone(rewriter.getI16Type()), input);
    if (vectorType.getElementType().isInteger(8) &&
        vectorType.getNumElements() <= 8)
      return rewriter.create<LLVM::BitcastOp>(
          loc, rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
    if (isa<IntegerType>(vectorType.getElementType()) &&
        vectorType.getElementTypeBitWidth() <= 8) {
      int64_t numWords = llvm::divideCeil(
          vectorType.getNumElements() * vectorType.getElementTypeBitWidth(),
          32);
      return rewriter.create<LLVM::BitcastOp>(
          loc, VectorType::get(numWords, rewriter.getI32Type()), input);
    }
  }
  return input;
}

/// Converts the scaled MFMA operands, `scalesA` and `scalesB`, from MLIR AMDGPU
/// dialect convention to ROCDL and LLVM AMDGPU intrinsics convention.
///
/// Specifically:
/// 1. If `input` is an i8 value, zero extend it to i32.
/// 2. If `input` is a vector of length 4 and type i8, bitcast it to i32.
///
/// Note that the type of `input` has already been LLVM type converted:
/// therefore 8-bit and smaller floats are represented as their corresponding
/// `iN` integers.
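/// For example (illustrative), a single 8-bit scale arrives here as an i8 and
/// is zero-extended to i32, while a packed vector<4xi8> of scales is bitcast
/// to i32.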
static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter,
                                  Location loc, Value input) {
  Type inputType = input.getType();
  Type outputType = rewriter.getI32Type();
  if (isa<IntegerType>(inputType))
    return rewriter.create<LLVM::ZExtOp>(loc, outputType, input);
  return rewriter.create<LLVM::BitcastOp>(loc, outputType, input);
}

/// Push an input operand. If it is a float type, nothing to do. If it is
/// an integer type, then we need to also push its signedness (1 for signed, 0
/// for unsigned) and we need to pack the input 16xi8 vector into a 4xi32
/// vector (or the 8xi8 vector into a 2xi32 one for gfx12+).
/// We also need to convert bfloat inputs to i16 to account for the bfloat
/// intrinsics having been defined before the AMD backend supported bfloat. We
/// similarly need to pack 8-bit float types into integers as if they were i8
/// (which they are for the backend's purposes).
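/// For example (illustrative), a signed vector<16xi8> input pushes an i1 "1"
/// (signed) followed by the data bitcast to vector<4xi32>.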
static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
                                 Location loc,
                                 const TypeConverter *typeConverter,
                                 bool isUnsigned, Value llvmInput,
                                 Value mlirInput,
                                 SmallVector<Value, 4> &operands) {
  Type inputType = llvmInput.getType();
  auto vectorType = dyn_cast<VectorType>(inputType);
  if (!vectorType) {
    operands.push_back(llvmInput);
    return;
  }
  Type elemType = vectorType.getElementType();

  if (elemType.isBF16())
    llvmInput = rewriter.create<LLVM::BitcastOp>(
        loc, vectorType.clone(rewriter.getI16Type()), llvmInput);
  if (elemType.getIntOrFloatBitWidth() > 8) {
    operands.push_back(llvmInput);
    return;
  }

  // We need to check the type of the input before conversion to properly test
  // for int8. This is because, in LLVM, the fp8 type is converted to int8, so
  // the fp8/int8 information is lost during the conversion process.
  auto mlirInputType = cast<VectorType>(mlirInput.getType());
  bool isInputInteger = mlirInputType.getElementType().isInteger();
  if (isInputInteger) {
    // If the element type is 8-bit signed or unsigned, ignore the isUnsigned
    // flag.
    bool localIsUnsigned = isUnsigned;
    if (elemType.isUnsignedInteger()) {
      localIsUnsigned = true;
    } else if (elemType.isSignedInteger()) {
      localIsUnsigned = false;
    }
    Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
    operands.push_back(sign);
  }

  int64_t numBits =
      vectorType.getNumElements() * elemType.getIntOrFloatBitWidth();
  Type i32 = rewriter.getI32Type();
  Type intrinsicInType = numBits <= 32
                             ? (Type)rewriter.getIntegerType(numBits)
                             : (Type)VectorType::get(numBits / 32, i32);
  auto llvmIntrinsicInType = typeConverter->convertType(intrinsicInType);
  Value castInput = rewriter.createOrFold<LLVM::BitcastOp>(
      loc, llvmIntrinsicInType, llvmInput);
  // The wave64-mode 16x16x16 intrinsics that take 4-bit integers only need
  // (256 / 64) * 4 = 16 bits of input (on gfx12+) but take i32 arguments.
  // Add in the zeros here.
  if (numBits < 32)
    castInput = rewriter.create<LLVM::ZExtOp>(loc, i32, castInput);
  operands.push_back(castInput);
}

/// Push the output operand. In many cases this is only pushing the output in
/// the operand list. But when we have f16 -> f16 or bf16 -> bf16 intrinsics,
/// since the same number of VGPRs is used, we need to decide whether to store
/// the result in the upper 16 bits of the VGPRs or in the lower part. To store
/// the result in the lower 16 bits, set subwordOffset to 1; otherwise the
/// result will be stored in the upper part. The subwordOffset must not be set
/// for gfx12, as the instructions have been changed to return fewer registers
/// instead.
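/// For example (illustrative), an f16 accumulator pushes the output vector
/// followed by an i1 encoding `subwordOffset`, while an i32 accumulator pushes
/// an i1 encoding `clamp`.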
static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
                                  Location loc,
                                  const TypeConverter *typeConverter,
                                  Value output, int32_t subwordOffset,
                                  bool clamp, SmallVector<Value, 4> &operands) {
  Type inputType = output.getType();
  auto vectorType = dyn_cast<VectorType>(inputType);
  Type elemType = vectorType.getElementType();
  if (elemType.isBF16())
    output = rewriter.create<LLVM::BitcastOp>(
        loc, vectorType.clone(rewriter.getI16Type()), output);
  operands.push_back(output);
  if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) {
    operands.push_back(createI1Constant(rewriter, loc, subwordOffset));
  } else if (elemType.isInteger(32)) {
    operands.push_back(createI1Constant(rewriter, loc, clamp));
  }
}

/// Return true if `type` is the E5M2 variant of an 8-bit float that is
/// supported by the `_bf8` instructions on the given `chipset`.
static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type) {
  return (chipset == kGfx942 && isa<Float8E5M2FNUZType>(type)) ||
         (hasOcpFp8(chipset) && isa<Float8E5M2Type>(type));
}

/// Return true if `type` is the E4M3FN variant of an 8-bit float that is
/// supported by the `_fp8` instructions on the given `chipset`.
static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type) {
  return (chipset == kGfx942 && isa<Float8E4M3FNUZType>(type)) ||
         (hasOcpFp8(chipset) && isa<Float8E4M3FNType>(type));
}

/// Return the `rocdl` intrinsic corresponding to a MFMA operation `mfma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
                                                  Chipset chipset) {
  uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
           b = mfma.getBlocks();
  Type sourceElem = getElementTypeOrSelf(mfma.getSourceA().getType());
  Type destElem = getElementTypeOrSelf(mfma.getDestC().getType());

  if (sourceElem.isF32() && destElem.isF32()) {
    if (mfma.getReducePrecision() && chipset >= kGfx942) {
      if (m == 32 && n == 32 && k == 4 && b == 1)
        return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
      if (m == 16 && n == 16 && k == 8 && b == 1)
        return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
    }
    if (m == 32 && n == 32 && k == 1 && b == 2)
      return ROCDL::mfma_f32_32x32x1f32::getOperationName();
    if (m == 16 && n == 16 && k == 1 && b == 4)
      return ROCDL::mfma_f32_16x16x1f32::getOperationName();
    if (m == 4 && n == 4 && k == 1 && b == 16)
      return ROCDL::mfma_f32_4x4x1f32::getOperationName();
    if (m == 32 && n == 32 && k == 2 && b == 1)
      return ROCDL::mfma_f32_32x32x2f32::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f32_16x16x4f32::getOperationName();
  }

  if (sourceElem.isF16() && destElem.isF32()) {
    if (chipset >= kGfx950) {
      if (m == 32 && n == 32 && k == 16 && b == 1)
        return ROCDL::mfma_f32_32x32x16_f16::getOperationName();
      if (m == 16 && n == 16 && k == 32 && b == 1)
        return ROCDL::mfma_f32_16x16x32_f16::getOperationName();
    }
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_f32_32x32x4f16::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_f32_16x16x4f16::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_f32_4x4x4f16::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_f32_32x32x8f16::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_f32_16x16x16f16::getOperationName();
  }

  if (sourceElem.isBF16() && destElem.isF32()) {
    if (chipset >= kGfx950) {
      if (m == 32 && n == 32 && k == 16 && b == 1)
        return ROCDL::mfma_f32_32x32x16_bf16::getOperationName();
      if (m == 16 && n == 16 && k == 32 && b == 1)
        return ROCDL::mfma_f32_16x16x32_bf16::getOperationName();
    }
    if (chipset >= kGfx90a) {
      if (m == 32 && n == 32 && k == 4 && b == 2)
        return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
      if (m == 16 && n == 16 && k == 4 && b == 4)
        return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
      if (m == 4 && n == 4 && k == 4 && b == 16)
        return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
      if (m == 32 && n == 32 && k == 8 && b == 1)
        return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
      if (m == 16 && n == 16 && k == 16 && b == 1)
        return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
    }
    if (m == 32 && n == 32 && k == 2 && b == 2)
      return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
    if (m == 16 && n == 16 && k == 2 && b == 4)
      return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
    if (m == 4 && n == 4 && k == 2 && b == 16)
      return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
    if (m == 32 && n == 32 && k == 4 && b == 1)
      return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
    if (m == 16 && n == 16 && k == 8 && b == 1)
      return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
  }

  if (sourceElem.isInteger(8) && destElem.isInteger(32)) {
    if (chipset >= kGfx950) {
      if (m == 32 && n == 32 && k == 32 && b == 1)
        return ROCDL::mfma_i32_32x32x32_i8::getOperationName();
      if (m == 16 && n == 16 && k == 64 && b == 1)
        return ROCDL::mfma_i32_16x16x64_i8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_i32_32x32x4i8::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_i32_16x16x4i8::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_i32_4x4x4i8::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_i32_32x32x8i8::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_i32_16x16x16i8::getOperationName();
    if (m == 32 && n == 32 && k == 16 && b == 1 && chipset >= kGfx942)
      return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
    if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >= kGfx942)
      return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
  }

  if (sourceElem.isF64() && destElem.isF64() && chipset >= kGfx90a) {
    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f64_16x16x4f64::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 4)
      return ROCDL::mfma_f64_4x4x4f64::getOperationName();
  }

  if (destElem.isF32() && typeIsExpectedBf8ForChipset(chipset, sourceElem)) {
    // Known to be correct because there are no scalar f8 instructions and
    // because a length mismatch will have been caught by the verifier.
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
    }
  }

  if (destElem.isF32() && typeIsExpectedFp8ForChipset(chipset, sourceElem)) {
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
    }
  }

  return std::nullopt;
}

static std::optional<uint32_t> mfmaTypeSelectCode(Type mlirElemType) {
  return llvm::TypeSwitch<Type, std::optional<uint32_t>>(mlirElemType)
      .Case([](Float8E4M3FNType) { return 0u; })
      .Case([](Float8E5M2Type) { return 1u; })
      .Case([](Float6E2M3FNType) { return 2u; })
      .Case([](Float6E3M2FNType) { return 3u; })
      .Case([](Float4E2M1FNType) { return 4u; })
      .Default([](Type) { return std::nullopt; });
}

/// If there is a scaled MFMA instruction for the input element types `aType`
/// and `bType`, output type `destType`, problem size M, N, K, and B (number of
/// blocks) on the given `chipset`, return a tuple consisting of the
/// OperationName of the intrinsic and the type codes that need to be passed to
/// that intrinsic. Note that this is also used to implement some un-scaled
/// MFMAs, since the compiler represents the ordinary instruction as a "scaled"
/// MFMA with a scale of 0.
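/// For example (illustrative), f8E4M3FN sources, an f32 destination, and a
/// 32x32x64 single-block problem on gfx950 map to
/// ROCDL::mfma_scale_f32_32x32x64_f8f6f4 with type codes (0, 0).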
static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m,
                        uint32_t n, uint32_t k, uint32_t b, Chipset chipset) {
  aType = getElementTypeOrSelf(aType);
  bType = getElementTypeOrSelf(bType);
  destType = getElementTypeOrSelf(destType);

  if (chipset < kGfx950)
    return std::nullopt;
  if (!isa<Float32Type>(destType))
    return std::nullopt;

  std::optional<uint32_t> aTypeCode = mfmaTypeSelectCode(aType);
  std::optional<uint32_t> bTypeCode = mfmaTypeSelectCode(bType);
  if (!aTypeCode || !bTypeCode)
    return std::nullopt;

  if (m == 32 && n == 32 && k == 64 && b == 1)
    return std::tuple{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(),
                      *aTypeCode, *bTypeCode};
  if (m == 16 && n == 16 && k == 128 && b == 1)
    return std::tuple{
        ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), *aTypeCode,
        *bTypeCode};

  return std::nullopt;
}

static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
mfmaOpToScaledIntrinsic(MFMAOp mfma, Chipset chipset) {
  return mfmaOpToScaledIntrinsic(
      mfma.getSourceA().getType(), mfma.getSourceB().getType(),
      mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(),
      mfma.getBlocks(), chipset);
}

static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
  return mfmaOpToScaledIntrinsic(smfma.getSourceA().getType(),
                                 smfma.getSourceB().getType(),
                                 smfma.getDestC().getType(), smfma.getM(),
                                 smfma.getN(), smfma.getK(), 1u, chipset);
}

/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
                                                  Chipset chipset) {
  auto sourceVectorType = dyn_cast<VectorType>(wmma.getSourceA().getType());
  auto sourceBVectorType = dyn_cast<VectorType>(wmma.getSourceB().getType());
  auto destVectorType = dyn_cast<VectorType>(wmma.getDestC().getType());
  auto elemSourceType = sourceVectorType.getElementType();
  auto elemBSourceType = sourceBVectorType.getElementType();
  auto elemDestType = destVectorType.getElementType();

  if (elemSourceType.isF16() && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
  if (elemSourceType.isBF16() && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
  if (elemSourceType.isF16() && elemDestType.isF16())
    return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
  if (elemSourceType.isBF16() && elemDestType.isBF16())
    return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
  if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
    return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
  if (chipset.majorVersion == 11) {
    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
      return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
  }
  if (chipset.majorVersion >= 12) {
    if (isa<Float8E4M3FNType>(elemSourceType) &&
        isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
    if (isa<Float8E4M3FNType>(elemSourceType) &&
        isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
    if (isa<Float8E5M2Type>(elemSourceType) &&
        isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
    if (isa<Float8E5M2Type>(elemSourceType) &&
        isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) {
      bool isWave64 = destVectorType.getNumElements() == 4;
      // This is the ambiguous case. 8 inputs to the wave64 version means that
      // we want the 16x16x32 version, but for wave32 they mean the short form.
      bool has8Inputs = sourceVectorType.getNumElements() == 8;
      if ((isWave64 && has8Inputs) || (!isWave64 && !has8Inputs))
        return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
      return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
    }
  }
  return std::nullopt;
}

namespace {
struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
  MFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type outType = typeConverter->convertType(op.getDestD().getType());
    Type intrinsicOutType = outType;
    if (auto outVecType = dyn_cast<VectorType>(outType))
      if (outVecType.getElementType().isBF16())
        intrinsicOutType = outVecType.clone(rewriter.getI16Type());

    if (chipset.majorVersion != 9 || chipset < kGfx908)
      return op->emitOpError("MFMA only supported on gfx908+");
    uint32_t getBlgpField = static_cast<uint32_t>(op.getBlgp());
    if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
      if (chipset < kGfx942)
        return op.emitOpError("negation unsupported on older than gfx942");
      getBlgpField |=
          op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
    }
    std::optional<StringRef> maybeIntrinsic = mfmaOpToIntrinsic(op, chipset);
    std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
        maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
    if (!maybeIntrinsic.has_value() && !maybeScaledIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching MFMA size on given chipset");

    bool isScaled =
        !maybeIntrinsic.has_value() && maybeScaledIntrinsic.has_value();
    if (isScaled &&
        (adaptor.getAbid() > 0 || getBlgpField > 0 || op.getCbsz() > 0)) {
      return op.emitOpError(
          "non-default abid, blgp, and cbsz aren't supported on MFMAs that can "
          "be scaled as those fields are used for type information");
    }

    StringRef intrinsicName =
        isScaled ? std::get<0>(*maybeScaledIntrinsic) : *maybeIntrinsic;
    OperationState loweredOp(loc, intrinsicName);
    loweredOp.addTypes(intrinsicOutType);
    loweredOp.addOperands(
        {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
         convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
         adaptor.getDestC()});
    if (isScaled) {
      Value zero = createI32Constant(rewriter, loc, 0);
      auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
      loweredOp.addOperands({createI32Constant(rewriter, loc, aTypeCode),
                             createI32Constant(rewriter, loc, bTypeCode),
                             /*scale A byte=*/zero, /*scale A=*/zero,
                             /*scale B byte=*/zero, /*scale B=*/zero});
    } else {
      loweredOp.addOperands({createI32Constant(rewriter, loc, op.getCbsz()),
                             createI32Constant(rewriter, loc, op.getAbid()),
                             createI32Constant(rewriter, loc, getBlgpField)});
    }
    Value lowered = rewriter.create(loweredOp)->getResult(0);
    if (outType != intrinsicOutType)
      lowered = rewriter.create<LLVM::BitcastOp>(loc, outType, lowered);
    rewriter.replaceOp(op, lowered);
    return success();
  }
};

struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
  ScaledMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type intrinsicOutType = typeConverter->convertType(op.getDestD().getType());

    if (chipset.majorVersion != 9 || chipset < kGfx950)
      return op->emitOpError("scaled MFMA only supported on gfx950+");
    std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
        maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
    if (!maybeScaledIntrinsic.has_value())
      return op.emitOpError(
          "no intrinsic matching scaled MFMA size on given chipset");

    auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
    OperationState loweredOp(loc, intrinsicName);
    loweredOp.addTypes(intrinsicOutType);
    loweredOp.addOperands(
        {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
         convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
         adaptor.getDestC()});
    Value scalesIdxA =
        createI32Constant(rewriter, loc, adaptor.getScalesIdxA());
    Value scalesIdxB =
        createI32Constant(rewriter, loc, adaptor.getScalesIdxB());
    loweredOp.addOperands(
        {createI32Constant(rewriter, loc, aTypeCode),
         createI32Constant(rewriter, loc, bTypeCode),
         /*scales idx A=*/scalesIdxA,
         /*scales A=*/
         castMFMAScaleOperand(rewriter, loc, adaptor.getScalesA()),
         /*scales idx B=*/scalesIdxB,
         /*scales B=*/
         castMFMAScaleOperand(rewriter, loc, adaptor.getScalesB())});
    Value lowered = rewriter.create(loweredOp)->getResult(0);
    rewriter.replaceOp(op, lowered);
    return success();
  }
};

struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
  WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto outType =
        typeConverter->convertType<VectorType>(op.getDestD().getType());
    if (!outType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
      return op->emitOpError("WMMA only supported on gfx11 and gfx12");

    // The WMMA operations represent vectors of bf16s as vectors of i16s, so we
    // need to bitcast bfloats to i16 and then bitcast them back.
    VectorType rawOutType = outType;
    if (outType.getElementType().isBF16())
      rawOutType = outType.clone(rewriter.getI16Type());

    std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);

    if (!maybeIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching WMMA on the given chipset");

    if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0)
      return op.emitOpError("subwordOffset not supported on gfx12+");

    OperationState loweredOp(loc, *maybeIntrinsic);
    loweredOp.addTypes(rawOutType);

    SmallVector<Value, 4> operands;
    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(),
                         adaptor.getSourceA(), op.getSourceA(), operands);
    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(),
                         adaptor.getSourceB(), op.getSourceB(), operands);
    wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(),
                          op.getSubwordOffset(), op.getClamp(), operands);

    loweredOp.addOperands(operands);
    Operation *lowered = rewriter.create(loweredOp);

    Operation *maybeCastBack = lowered;
    if (rawOutType != outType)
      maybeCastBack =
          rewriter.create<LLVM::BitcastOp>(loc, outType, lowered->getResult(0));
    rewriter.replaceOp(op, maybeCastBack->getResults());

    return success();
  }
};

struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
  GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset.majorVersion < 9 || chipset.majorVersion > 10)
      return op.emitOpError("pre-gfx9 and post-gfx10 not supported");

    Location loc = op.getLoc();

    auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
    auto dstMemRefType = cast<MemRefType>(op.getDst().getType());

    // TODO: instead of only transferring one element per thread, we could
    // augment it to transfer multiple elements per thread by issuing multiple
    // `global_load_lds` instructions.
    Type transferType = op.getTransferType();
    size_t loadWidth = [&]() -> size_t {
      if (auto transferVectorType = dyn_cast<VectorType>(transferType)) {
        return transferVectorType.getNumElements() *
               (transferVectorType.getElementTypeBitWidth() / 8);
      }
      return transferType.getIntOrFloatBitWidth() / 8;
    }();

    // Currently only 1, 2, and 4 byte loads are supported.
    if (loadWidth != 1 && loadWidth != 2 && loadWidth != 4)
      return op.emitOpError("unsupported element size for this chipset");

    Value srcPtr =
        getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
                             adaptor.getSrcIndices());
    Value dstPtr =
        getStridedElementPtr(rewriter, loc, dstMemRefType, adaptor.getDst(),
                             adaptor.getDstIndices());

    rewriter.replaceOpWithNewOp<ROCDL::LoadToLDSOp>(
        op, srcPtr, dstPtr, rewriter.getI32IntegerAttr(loadWidth),
        /*offset=*/rewriter.getI32IntegerAttr(0),
        /*aux=*/rewriter.getI32IntegerAttr(0), ArrayAttr{}, ArrayAttr{},
        ArrayAttr{});

    return success();
  }
};

namespace {
struct ExtPackedFp8OpLowering final
    : public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
  ExtPackedFp8OpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::ExtPackedFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

struct PackedTrunc2xFp8OpLowering final
    : public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> {
  PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter,
                             Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::PackedTrunc2xFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

struct PackedStochRoundFp8OpLowering final
    : public ConvertOpToLLVMPattern<PackedStochRoundFp8Op> {
  PackedStochRoundFp8OpLowering(const LLVMTypeConverter &converter,
                                Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::PackedStochRoundFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PackedStochRoundFp8Op op,
                  PackedStochRoundFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // end namespace

LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
    ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type v4i8 =
      getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type()));
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
  Type f32 = getTypeConverter()->convertType(op.getResult().getType());

  Value source = adaptor.getSource();
  auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
  auto resultVecType = dyn_cast<VectorType>(op.getResult().getType());
  Type sourceElemType = getElementTypeOrSelf(op.getSource());
  // Extend to a v4i8
  if (!sourceVecType || sourceVecType.getNumElements() < 4) {
    Value longVec = rewriter.create<LLVM::UndefOp>(loc, v4i8);
    if (!sourceVecType) {
      longVec = rewriter.create<LLVM::InsertElementOp>(
          loc, longVec, source, createI32Constant(rewriter, loc, 0));
    } else {
      for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
        Value idx = createI32Constant(rewriter, loc, i);
        Value elem = rewriter.create<LLVM::ExtractElementOp>(loc, source, idx);
        longVec =
            rewriter.create<LLVM::InsertElementOp>(loc, longVec, elem, idx);
      }
    }
    source = longVec;
  }
  Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
  if (resultVecType) {
    Value wordSel = createI1Constant(rewriter, loc, op.getIndex());
    if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
      rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source,
                                                        wordSel);
    } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
      rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Fp8Op>(op, f32, i32Source,
                                                        wordSel);
    }
  } else {
    Value byteSel = createI32Constant(rewriter, loc, op.getIndex());
    if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
      rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
                                                      byteSel);
    } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
      rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
                                                      byteSel);
    }
  }
  return success();
}

LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
    PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Type resultType = op.getResult().getType();
  Type resultElemType = getElementTypeOrSelf(resultType);

  Value sourceA = adaptor.getSourceA();
  Value sourceB = adaptor.getSourceB();
  if (!sourceB)
    sourceB = rewriter.create<LLVM::UndefOp>(loc, sourceA.getType());
  Value existing = adaptor.getExisting();
  if (existing)
    existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
  else
    existing = rewriter.create<LLVM::UndefOp>(loc, i32);
  Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex());

  Value result;
  if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
    result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
                                                   existing, wordSel);
  else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
    result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
                                                   existing, wordSel);

  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
  return success();
}

LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
    PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Type resultType = op.getResult().getType();
  Type resultElemType = getElementTypeOrSelf(resultType);

  Value source = adaptor.getSource();
  Value stoch = adaptor.getStochiasticParam();
  Value existing = adaptor.getExisting();
  if (existing)
    existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
  else
    existing = rewriter.create<LLVM::UndefOp>(loc, i32);
  Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex());

  Value result;
  if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
    result = rewriter.create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch,
                                                   existing, byteSel);
  else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
    result = rewriter.create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch,
                                                   existing, byteSel);

  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
  return success();
}
1305 
1306 // Implement the AMDGPU_DPPLowering class that will convert the amdgpu.dpp
1307 // operation into the corresponding ROCDL instructions.
1308 struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
1309  AMDGPUDPPLowering(const LLVMTypeConverter &converter, Chipset chipset)
1310  : ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
1311  Chipset chipset;
1312 
1313  LogicalResult
1314  matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
1315  ConversionPatternRewriter &rewriter) const override {
1316 
1317  // Convert the source operand to the corresponding LLVM type
1318  Location loc = DppOp.getLoc();
1319  Value src = adaptor.getSrc();
1320  Value old = adaptor.getOld();
1321  Type srcType = src.getType();
1322  Type oldType = old.getType();
1323  Type llvmType = nullptr;
1324  if (srcType.getIntOrFloatBitWidth() < 32) {
1325  llvmType = rewriter.getI32Type();
1326  } else if (isa<FloatType>(srcType)) {
1327  llvmType = (srcType.getIntOrFloatBitWidth() == 32)
1328  ? rewriter.getF32Type()
1329  : rewriter.getF64Type();
1330  } else if (isa<IntegerType>(srcType)) {
1331  llvmType = (srcType.getIntOrFloatBitWidth() == 32)
1332  ? rewriter.getI32Type()
1333  : rewriter.getI64Type();
1334  }
1335  auto llvmSrcIntType = typeConverter->convertType(
1336  rewriter.getIntegerType(srcType.getIntOrFloatBitWidth()));
1337 
1338  // If the source type is less of 32, use bitcast to convert it to i32.
1339  auto convertOperand = [&](Value operand, Type operandType) {
1340  if (operandType.getIntOrFloatBitWidth() <= 16) {
1341  if (llvm::isa<FloatType>(operandType)) {
1342  operand =
1343  rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
1344  }
1345  auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
1346  32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
1347  Value undefVec = rewriter.create<LLVM::UndefOp>(loc, llvmVecType);
1348  operand = rewriter.create<LLVM::InsertElementOp>(
1349  loc, undefVec, operand, createI32Constant(rewriter, loc, 0));
1350  operand = rewriter.create<LLVM::BitcastOp>(loc, llvmType, operand);
1351  }
1352  return operand;
1353  };
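// For instance, an f16 operand follows the path:
//   f16 -> bitcast to i16 -> insertelement into vector<2xi16>[0]
//       -> bitcast to i32
// so the DPP intrinsic always operates on a full 32-bit register.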
1354 
1355  src = convertOperand(src, srcType);
1356  old = convertOperand(old, oldType);
1357 
1358  // These DPP control-word encodings are taken from llvm/lib/Target/AMDGPU/SIDefines.h.
1359  enum DppCtrl : unsigned {
1360  ROW_SHL0 = 0x100,
1361  ROW_SHR0 = 0x110,
1362  ROW_ROR0 = 0x120,
1363  WAVE_SHL1 = 0x130,
1364  WAVE_ROL1 = 0x134,
1365  WAVE_SHR1 = 0x138,
1366  WAVE_ROR1 = 0x13C,
1367  ROW_MIRROR = 0x140,
1368  ROW_HALF_MIRROR = 0x141,
1369  BCAST15 = 0x142,
1370  BCAST31 = 0x143,
1371  };
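// Row shifts/rotates encode as base + amount (e.g. row_shl(4) is
// ROW_SHL0 + 4 = 0x104), while quad_perm packs four 2-bit lane selectors
// (e.g. the identity permutation [0, 1, 2, 3] packs to 0xE4).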
1372 
1373  auto kind = DppOp.getKind();
1374  auto permArgument = DppOp.getPermArgument();
1375  uint32_t DppCtrl = 0;
1376 
1377  switch (kind) {
1378 
1379  case DPPPerm::quad_perm:
1380  if (auto quadPermAttr = cast<ArrayAttr>(*permArgument)) {
1381  int32_t i = 0;
1382  for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
1383  uint32_t num = elem.getInt();
1384  DppCtrl |= num << (i * 2);
1385  i++;
1386  }
1387  }
1388  break;
1389  case DPPPerm::row_shl:
1390  if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
1391  DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
1392  }
1393  break;
1394  case DPPPerm::row_shr:
1395  if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
1396  DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
1397  }
1398  break;
1399  case DPPPerm::row_ror:
1400  if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
1401  DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
1402  }
1403  break;
1404  case DPPPerm::wave_shl:
1405  DppCtrl = DppCtrl::WAVE_SHL1;
1406  break;
1407  case DPPPerm::wave_shr:
1408  DppCtrl = DppCtrl::WAVE_SHR1;
1409  break;
1410  case DPPPerm::wave_rol:
1411  DppCtrl = DppCtrl::WAVE_ROL1;
1412  break;
1413  case DPPPerm::wave_ror:
1414  DppCtrl = DppCtrl::WAVE_ROR1;
1415  break;
1416  case DPPPerm::row_mirror:
1417  DppCtrl = DppCtrl::ROW_MIRROR;
1418  break;
1419  case DPPPerm::row_half_mirror:
1420  DppCtrl = DppCtrl::ROW_HALF_MIRROR;
1421  break;
1422  case DPPPerm::row_bcast_15:
1423  DppCtrl = DppCtrl::BCAST15;
1424  break;
1425  case DPPPerm::row_bcast_31:
1426  DppCtrl = DppCtrl::BCAST31;
1427  break;
1428  }
1429 
1430  // Read the row_mask, bank_mask, and bound_ctrl attribute values to pass
1431  // to the DPP update op.
1432  auto rowMask = DppOp->getAttrOfType<IntegerAttr>("row_mask").getInt();
1433  auto bankMask = DppOp->getAttrOfType<IntegerAttr>("bank_mask").getInt();
1434  bool boundCtrl = DppOp->getAttrOfType<BoolAttr>("bound_ctrl").getValue();
1435 
1436  // Create a ROCDL DPPUpdateOp (rocdl.update.dpp) with the computed control word and masks.
1437  auto dppMovOp = rewriter.create<ROCDL::DPPUpdateOp>(
1438  loc, llvmType, old, src, DppCtrl, rowMask, bankMask, boundCtrl);
1439 
1440  Value result = dppMovOp.getRes();
1441  if (srcType.getIntOrFloatBitWidth() < 32) {
1442  result = rewriter.create<LLVM::TruncOp>(loc, llvmSrcIntType, result);
1443  if (!llvm::isa<IntegerType>(srcType)) {
1444  result = rewriter.create<LLVM::BitcastOp>(loc, srcType, result);
1445  }
1446  }
1447 
1448  // Replace the original amdgpu.dpp operation with the result of the
1449  // ROCDL DPP update operation.
1450  rewriter.replaceOp(DppOp, ValueRange(result));
1451  return success();
1452  }
1453 };
1454 
1455 struct AMDGPUSwizzleBitModeLowering
1456  : public ConvertOpToLLVMPattern<SwizzleBitModeOp> {
1458  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
1459  LogicalResult
1460  matchAndRewrite(SwizzleBitModeOp op, OpAdaptor adaptor,
1461  ConversionPatternRewriter &rewriter) const override {
1462  Location loc = op.getLoc();
1463  Type i32 = rewriter.getI32Type();
1464  Value src = adaptor.getSrc();
1465  SmallVector<Value> decomposed =
1466  LLVM::decomposeValue(rewriter, loc, src, i32);
1467  unsigned andMask = op.getAndMask();
1468  unsigned orMask = op.getOrMask();
1469  unsigned xorMask = op.getXorMask();
1470 
1471  // Bit 15 is 0 for the BitMode swizzle; the and/or/xor masks occupy bits 0-4, 5-9, and 10-14.
1472  // https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/
1473  unsigned mask = andMask | (orMask << 5) | (xorMask << 10);
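// For example, andMask = 0x1f, orMask = 0, and xorMask = 1 give
// mask = 0x41f, i.e. newLane = ((lane & 0x1f) | 0x0) ^ 0x1, which swaps
// adjacent lanes within each group of 32.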
1474  Value maskValue = createI32Constant(rewriter, loc, mask);
1475  SmallVector<Value> swizzled;
1476  for (Value v : decomposed) {
1477  Value res =
1478  rewriter.create<ROCDL::DsSwizzleOp>(loc, v.getType(), v, maskValue);
1479  swizzled.emplace_back(res);
1480  }
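// Sources wider than 32 bits are swizzled piecewise: an i64 value, say, is
// decomposed into two i32 words above, each word is swizzled independently,
// and the swizzled words are recomposed into the original type below.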
1481 
1482  Value result = LLVM::composeValue(rewriter, loc, swizzled, src.getType());
1483  rewriter.replaceOp(op, result);
1484  return success();
1485  }
1486 };
1487 
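// A typical invocation of this pass from the command line might look like
// (sketch; the `chipset` option selects the target architecture):
//   mlir-opt --convert-amdgpu-to-rocdl='chipset=gfx942' input.mlir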
1488 struct ConvertAMDGPUToROCDLPass
1489  : public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
1490  using Base::Base;
1491 
1492  void runOnOperation() override {
1493  MLIRContext *ctx = &getContext();
1494  FailureOr<Chipset> maybeChipset = Chipset::parse(chipset);
1495  if (failed(maybeChipset)) {
1496  emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
1497  return signalPassFailure();
1498  }
1499 
1500  RewritePatternSet patterns(ctx);
1501  LLVMTypeConverter converter(ctx);
1502  populateAMDGPUToROCDLConversionPatterns(converter, patterns, *maybeChipset);
1503  LLVMConversionTarget target(getContext());
1504  target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
1505  target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
1506  target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
1507  if (failed(applyPartialConversion(getOperation(), target,
1508  std::move(patterns))))
1509  signalPassFailure();
1510  }
1511 };
1512 } // namespace
1513 
1514 void mlir::populateAMDGPUMemorySpaceAttributeConversions(
1515  TypeConverter &typeConverter) {
1516  typeConverter.addTypeAttributeConversion(
1517  [](BaseMemRefType type, amdgpu::AddressSpaceAttr as)
1518  -> TypeConverter::AttributeConversionResult {
1519  MLIRContext *ctx = as.getContext();
1520  Type i64 = IntegerType::get(ctx, 64);
1521  switch (as.getValue()) {
1522  case amdgpu::AddressSpace::FatRawBuffer:
1523  return IntegerAttr::get(i64, 7);
1524  case amdgpu::AddressSpace::BufferRsrc:
1525  return IntegerAttr::get(i64, 8);
1526  case amdgpu::AddressSpace::FatStructuredBuffer:
1527  return IntegerAttr::get(i64, 9);
1528  }
1529  return TypeConverter::AttributeConversionResult::abort();
1530  });
1531 }
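// With these conversions registered, a memref such as
//   memref<8xf32, #amdgpu.address_space<fat_raw_buffer>>
// ends up addressed through LLVM address space 7 after type conversion.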
1532 
1533 void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
1534  RewritePatternSet &patterns,
1535  Chipset chipset) {
1536  populateAMDGPUMemorySpaceAttributeConversions(converter);
1537  patterns
1538  .add<FatRawBufferCastLowering,
1539  RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
1540  RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
1541  RawBufferOpLowering<RawBufferAtomicFaddOp,
1542  ROCDL::RawPtrBufferAtomicFaddOp>,
1543  RawBufferOpLowering<RawBufferAtomicFmaxOp,
1544  ROCDL::RawPtrBufferAtomicFmaxOp>,
1545  RawBufferOpLowering<RawBufferAtomicSmaxOp,
1546  ROCDL::RawPtrBufferAtomicSmaxOp>,
1547  RawBufferOpLowering<RawBufferAtomicUminOp,
1548  ROCDL::RawPtrBufferAtomicUminOp>,
1549  RawBufferOpLowering<RawBufferAtomicCmpswapOp,
1550  ROCDL::RawPtrBufferAtomicCmpSwap>,
1551  AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
1552  MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering,
1553  ExtPackedFp8OpLowering, PackedTrunc2xFp8OpLowering,
1554  PackedStochRoundFp8OpLowering, GatherToLDSOpLowering>(converter,
1555  chipset);
1556  patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
1557 }