//===- AMDGPUToROCDL.cpp - AMDGPU to ROCDL dialect conversion -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"

#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"

#include "llvm/ADT/STLExtras.h"
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_CONVERTAMDGPUTOROCDL
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::amdgpu;

/// Convert an unsigned number `val` to i32.
static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
                                  Location loc, Value val) {
  IntegerType i32 = rewriter.getI32Type();
  // Force check that `val` is of int type.
  auto valTy = cast<IntegerType>(val.getType());
  if (i32 == valTy)
    return val;
  return valTy.getWidth() > 32
             ? Value(rewriter.create<LLVM::TruncOp>(loc, i32, val))
             : Value(rewriter.create<LLVM::ZExtOp>(loc, i32, val));
}

static Value createI32Constant(ConversionPatternRewriter &rewriter,
                               Location loc, int32_t value) {
  Type i32 = rewriter.getI32Type();
  return rewriter.create<LLVM::ConstantOp>(loc, i32, value);
}

static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
                              bool value) {
  Type llvmI1 = rewriter.getI1Type();
  return rewriter.create<LLVM::ConstantOp>(loc, llvmI1, value);
}

/// Returns the linear index used to access an element in the memref.
static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
                               Location loc, MemRefDescriptor &memRefDescriptor,
                               ValueRange indices, ArrayRef<int64_t> strides) {
  IntegerType i32 = rewriter.getI32Type();
  Value index;
  for (auto [i, increment, stride] : llvm::enumerate(indices, strides)) {
    if (stride != 1) { // Skip if stride is 1.
      Value strideValue =
          ShapedType::isDynamic(stride)
              ? convertUnsignedToI32(rewriter, loc,
                                     memRefDescriptor.stride(rewriter, loc, i))
              : rewriter.create<LLVM::ConstantOp>(loc, i32, stride);
      increment = rewriter.create<LLVM::MulOp>(loc, increment, strideValue);
    }
    index =
        index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
  }
  return index ? index : createI32Constant(rewriter, loc, 0);
}
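
// Worked example (editor's note, hypothetical shapes): for a memref<4x8xf32>
// with static strides [8, 1] and indices (%i, %j), the loop above produces
// %i * 8 + %j (the stride-1 dimension skips its multiply), and a rank-0
// access falls through to the constant 0.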

namespace {
// Define commonly used chipset versions for convenience.
constexpr Chipset kGfx908 = Chipset(9, 0, 8);
constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
constexpr Chipset kGfx940 = Chipset(9, 4, 0);

/// Define lowering patterns for raw buffer ops.
template <typename GpuOp, typename Intrinsic>
struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
  RawBufferOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}

  Chipset chipset;
  static constexpr uint32_t maxVectorOpWidth = 128;

  LogicalResult
  matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = gpuOp.getLoc();
    Value memref = adaptor.getMemref();
    Value unconvertedMemref = gpuOp.getMemref();
    MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());

    if (chipset.majorVersion < 9)
      return gpuOp.emitOpError("raw buffer ops require GCN or higher");

    Value storeData = adaptor.getODSOperands(0)[0];
    if (storeData == memref) // no write component to this op
      storeData = Value();
    Type wantedDataType;
    if (storeData)
      wantedDataType = storeData.getType();
    else
      wantedDataType = gpuOp.getODSResults(0)[0].getType();

    Value atomicCmpData = Value();
    // Operand index 1 of a load is the indices; trying to read them as
    // compare data can crash, so only do so when there is store data.
    if (storeData) {
      Value maybeCmpData = adaptor.getODSOperands(1)[0];
      if (maybeCmpData != memref)
        atomicCmpData = maybeCmpData;
    }

    Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);

    Type i32 = rewriter.getI32Type();
    Type i16 = rewriter.getI16Type();

    // Get the type size in bytes.
    DataLayout dataLayout = DataLayout::closest(gpuOp);
    int64_t elementByteWidth =
        dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;
    Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);

    // If we want to load a vector<NxT> with total size <= 32 bits, use a
    // scalar load and bitcast it. Similarly, if bitsize(T) < 32 and the total
    // load size is >= 32, use a vector load of (N * bitsize(T)) / 32 x i32 and
    // bitcast. Also, the CAS intrinsic requires integer operands, so bitcast
    // any floats to integers.
    Type llvmBufferValType = llvmWantedDataType;
    if (atomicCmpData) {
      if (auto floatType = dyn_cast<FloatType>(wantedDataType))
        llvmBufferValType = this->getTypeConverter()->convertType(
            rewriter.getIntegerType(floatType.getWidth()));
    }
    if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
      uint32_t vecLen = dataVector.getNumElements();
      uint32_t elemBits =
          dataLayout.getTypeSizeInBits(dataVector.getElementType());
      uint32_t totalBits = elemBits * vecLen;
      bool usePackedFp16 =
          isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
      if (totalBits > maxVectorOpWidth)
        return gpuOp.emitOpError(
            "Total width of loads or stores must be no more than " +
            Twine(maxVectorOpWidth) + " bits, but we call for " +
            Twine(totalBits) +
            " bits. This should've been caught in validation");
      if (!usePackedFp16 && elemBits < 32) {
        if (totalBits > 32) {
          if (totalBits % 32 != 0)
            return gpuOp.emitOpError("Load or store of more than 32-bits that "
                                     "doesn't fit into words. Can't happen\n");
          llvmBufferValType = this->typeConverter->convertType(
              VectorType::get(totalBits / 32, i32));
        } else {
          llvmBufferValType = this->typeConverter->convertType(
              rewriter.getIntegerType(totalBits));
        }
      }
    }
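
    // Concrete cases of the above rules (editor's note): a vector<2xf16>
    // operand of a buffer fadd stays packed for the packed-fp16 intrinsic;
    // a vector<4xi8> value (32 bits total) becomes a single i32; and a
    // vector<8xf16> value (128 bits) becomes a vector<4xi32>, with a bitcast
    // back to the requested type after the intrinsic below.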

    SmallVector<Value, 6> args;
    if (storeData) {
      if (llvmBufferValType != llvmWantedDataType) {
        Value castForStore =
            rewriter.create<LLVM::BitcastOp>(loc, llvmBufferValType, storeData);
        args.push_back(castForStore);
      } else {
        args.push_back(storeData);
      }
    }

    if (atomicCmpData) {
      if (llvmBufferValType != llvmWantedDataType) {
        Value castForCmp = rewriter.create<LLVM::BitcastOp>(
            loc, llvmBufferValType, atomicCmpData);
        args.push_back(castForCmp);
      } else {
        args.push_back(atomicCmpData);
      }
    }

    // Construct the buffer descriptor from the memref and attributes.
    int64_t offset = 0;
    SmallVector<int64_t, 5> strides;
    if (failed(memrefType.getStridesAndOffset(strides, offset)))
      return gpuOp.emitOpError("Can't lower non-stride-offset memrefs");

    MemRefDescriptor memrefDescriptor(memref);

    Value ptr = memrefDescriptor.bufferPtr(
        rewriter, loc, *this->getTypeConverter(), memrefType);
    // The stride value is always 0 for raw buffers. This also disables
    // swizzling.
    Value stride = rewriter.create<LLVM::ConstantOp>(
        loc, i16, rewriter.getI16IntegerAttr(0));
    // Get the number of elements.
    Value numRecords;
    if (memrefType.hasStaticShape() &&
        !llvm::any_of(strides, ShapedType::isDynamic)) {
      int64_t size = memrefType.getRank() == 0 ? 1 : 0;
      ArrayRef<int64_t> shape = memrefType.getShape();
      for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
        size = std::max(shape[i] * strides[i], size);
      size = size * elementByteWidth;
      assert(size < std::numeric_limits<uint32_t>::max() &&
             "the memref buffer is too large");
      numRecords = createI32Constant(rewriter, loc, static_cast<int32_t>(size));
    } else {
      Value maxIndex;
      for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
        Value size = memrefDescriptor.size(rewriter, loc, i);
        Value stride = memrefDescriptor.stride(rewriter, loc, i);
        Value maxThisDim = rewriter.create<LLVM::MulOp>(loc, size, stride);
        maxIndex =
            maxIndex ? rewriter.create<LLVM::UMaxOp>(loc, maxIndex, maxThisDim)
                     : maxThisDim;
      }
      numRecords = rewriter.create<LLVM::MulOp>(
          loc, convertUnsignedToI32(rewriter, loc, maxIndex), byteWidthConst);
    }

    // Flag word:
    // bits 0-11: dst sel, ignored by these intrinsics
    // bits 12-14: num format (ignored, must be nonzero, 7=float)
    // bits 15-18: data format (ignored, must be nonzero, 4=32bit)
    // bit 19: In nested heap (0 here)
    // bit 20: Behavior on unmap (0 means "return 0 / ignore")
    // bits 21-22: Index stride for swizzles (N/A)
    // bit 23: Add thread ID (0)
    // bit 24: Reserved to 1 (RDNA) or 0 (CDNA)
    // bits 25-26: Reserved (0)
    // bit 27: Buffer is non-volatile (CDNA only)
    // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 =
    // none, 3 = either swizzles or testing against offset field) RDNA only
    // bits 30-31: Type (must be 0)
    uint32_t flags = (7 << 12) | (4 << 15);
    if (chipset.majorVersion >= 10) {
      flags |= (1 << 24);
      uint32_t oob = adaptor.getBoundsCheck() ? 3 : 2;
      flags |= (oob << 28);
    }
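    // Concrete values (editor's note): on CDNA (gfx9xx) the flag word is
    // (7 << 12) | (4 << 15) = 0x27000; on RDNA with bounds checking enabled
    // it is 0x27000 | (1 << 24) | (3u << 28) = 0x31027000.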
    Value flagsConst = createI32Constant(rewriter, loc, flags);
    Type rsrcType = LLVM::LLVMPointerType::get(rewriter.getContext(), 8);
    Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
        loc, rsrcType, ptr, stride, numRecords, flagsConst);
    args.push_back(resource);

    // Indexing (voffset).
    Value voffset = getLinearIndexI32(rewriter, loc, memrefDescriptor,
                                      adaptor.getIndices(), strides);
    if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
        indexOffset && *indexOffset > 0) {
      Value extraOffsetConst = createI32Constant(rewriter, loc, *indexOffset);
      voffset =
          voffset ? rewriter.create<LLVM::AddOp>(loc, voffset, extraOffsetConst)
                  : extraOffsetConst;
    }
    voffset = rewriter.create<LLVM::MulOp>(loc, voffset, byteWidthConst);
    args.push_back(voffset);

    // SGPR offset.
    Value sgprOffset = adaptor.getSgprOffset();
    if (!sgprOffset)
      sgprOffset = createI32Constant(rewriter, loc, 0);
    sgprOffset = rewriter.create<LLVM::MulOp>(loc, sgprOffset, byteWidthConst);
    args.push_back(sgprOffset);

    // bit 0: GLC = 0 (atomics drop value, less coherency)
    // bits 1-2: SLC, DLC = 0 (similarly)
    // bit 3: swizzled (0 for raw)
    args.push_back(createI32Constant(rewriter, loc, 0));

    llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
                                           llvmBufferValType);
    Operation *lowered = rewriter.create<Intrinsic>(loc, resultTypes, args,
                                                    ArrayRef<NamedAttribute>());
    if (lowered->getNumResults() == 1) {
      Value replacement = lowered->getResult(0);
      if (llvmBufferValType != llvmWantedDataType) {
        replacement = rewriter.create<LLVM::BitcastOp>(loc, llvmWantedDataType,
                                                       replacement);
      }
      rewriter.replaceOp(gpuOp, replacement);
    } else {
      rewriter.eraseOp(gpuOp);
    }
    return success();
  }
};

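// A rough before/after sketch of this lowering (editor's addition; the MLIR
// below is illustrative, not verified output):
//
//   %v = amdgpu.raw_buffer_load {boundsCheck = true} %buf[%i]
//       : memref<128xf32>, i32 -> f32
//
// becomes approximately
//
//   %rsrc = rocdl.make.buffer.rsrc %ptr, %stride, %numRecords, %flags
//   %v = rocdl.raw.ptr.buffer.load %rsrc, %voffset, %sgprOffset, %aux : f32
//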
struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
  LDSBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    bool requiresInlineAsm = chipset < kGfx90a || chipset.majorVersion == 11;

    if (requiresInlineAsm) {
      auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
                                                      LLVM::AsmDialect::AD_ATT);
      const char *asmStr =
          ";;;WARNING: BREAKS DEBUG WATCHES\ns_waitcnt lgkmcnt(0)\ns_barrier";
      const char *constraints = "";
      rewriter.replaceOpWithNewOp<LLVM::InlineAsmOp>(
          op,
          /*resultTypes=*/TypeRange(), /*operands=*/ValueRange(),
          /*asm_string=*/asmStr, constraints, /*has_side_effects=*/true,
          /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
          /*operand_attrs=*/ArrayAttr());
      return success();
    }
    if (chipset.majorVersion < 12) {
      constexpr int32_t ldsOnlyBitsGfx6789 = ~(0x1f << 8);
      constexpr int32_t ldsOnlyBitsGfx10 = ~(0x3f << 8);
      // Left in place in case someone disables the inline ASM path or future
      // chipsets use the same bit pattern.
      constexpr int32_t ldsOnlyBitsGfx11 = ~(0x3f << 4);

      int32_t ldsOnlyBits;
      if (chipset.majorVersion == 11)
        ldsOnlyBits = ldsOnlyBitsGfx11;
      else if (chipset.majorVersion == 10)
        ldsOnlyBits = ldsOnlyBitsGfx10;
      else if (chipset.majorVersion <= 9)
        ldsOnlyBits = ldsOnlyBitsGfx6789;
      else
        return op.emitOpError(
                   "don't know how to lower this for chipset major version")
               << chipset.majorVersion;

      Location loc = op->getLoc();
      rewriter.create<ROCDL::SWaitcntOp>(loc, ldsOnlyBits);
      rewriter.replaceOpWithNewOp<ROCDL::SBarrierOp>(op);
    } else {
      Location loc = op->getLoc();
      rewriter.create<ROCDL::WaitDscntOp>(loc, 0);
      rewriter.create<ROCDL::BarrierSignalOp>(loc, -1);
      rewriter.replaceOpWithNewOp<ROCDL::BarrierWaitOp>(op, -1);
    }

    return success();
  }
};

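// Editor's note on the three paths above: on gfx90a+ and gfx10 this emits
// rocdl.s.waitcnt (with the non-LDS counter bits masked) plus rocdl.s.barrier;
// on gfx12+ it emits the wait_dscnt / barrier.signal / barrier.wait sequence;
// and on pre-gfx90a or gfx11 it falls back to the inline-asm path.
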
struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
  SchedBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<SchedBarrierOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<ROCDL::SchedBarrier>(op,
                                                     (uint32_t)op.getOpts());
    return success();
  }
};

} // namespace

/// Converts a MFMA vector operand from MLIR AMDGPU dialect convention to ROCDL
/// and LLVM AMDGPU intrinsics convention.
///
/// Specifically:
/// 1. If `input` is a vector of N bytes, bitcast it to a (N * 8)-bit integer.
/// 2. If the element type is bfloat16, bitcast it to i16.
static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
                                      Location loc, Value input) {
  Type inputType = input.getType();
  if (auto vectorType = dyn_cast<VectorType>(inputType)) {
    if (vectorType.getElementType().isBF16())
      return rewriter.create<LLVM::BitcastOp>(
          loc, vectorType.clone(rewriter.getI16Type()), input);
    if (vectorType.getElementType().isInteger(8)) {
      return rewriter.create<LLVM::BitcastOp>(
          loc, rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
    }
  }
  return input;
}

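// Editor's example of the above: a vector<4xbf16> operand becomes a
// vector<4xi16> bitcast, a vector<4xi8> operand becomes an i32 (an N x i8
// vector generally becomes an (8 * N)-bit integer), and all other operands
// pass through unchanged.
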
/// Push an input operand. If it is a float type, nothing to do. If it is
/// an integer type, then we need to also push its signedness (1 for signed, 0
/// for unsigned) and we need to pack the input 16xi8 vector into a 4xi32
/// vector. We also need to convert bfloat inputs to i16 to account for the
/// lack of bfloat support in the WMMA intrinsics themselves.
static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
                                 Location loc,
                                 const TypeConverter *typeConverter,
                                 bool isUnsigned, Value llvmInput,
                                 Value mlirInput,
                                 SmallVector<Value, 4> &operands) {
  Type inputType = llvmInput.getType();
  auto vectorType = dyn_cast<VectorType>(inputType);
  Type elemType = vectorType.getElementType();

  if (elemType.isBF16())
    llvmInput = rewriter.create<LLVM::BitcastOp>(
        loc, vectorType.clone(rewriter.getI16Type()), llvmInput);
  if (!elemType.isInteger(8)) {
    operands.push_back(llvmInput);
    return;
  }

  // We need to check the type of the input before conversion to properly test
  // for int8. This is because, in LLVM, the fp8 type is converted to int8, so
  // the fp8/int8 information is lost during the conversion process.
  auto mlirInputType = cast<VectorType>(mlirInput.getType());
  bool isInputInt8 = mlirInputType.getElementType().isInteger(8);
  if (isInputInt8) {
    // If the element type is explicitly 8-bit signed or unsigned, ignore the
    // isUnsigned flag.
    bool localIsUnsigned = isUnsigned;
    if (elemType.isUnsignedInteger(8)) {
      localIsUnsigned = true;
    } else if (elemType.isSignedInteger(8)) {
      localIsUnsigned = false;
    }
    Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
    operands.push_back(sign);
  }

  int64_t numBytes = vectorType.getNumElements();
  Type i32 = rewriter.getI32Type();
  VectorType vectorType32bits = VectorType::get(numBytes * 8 / 32, i32);
  auto llvmVectorType32bits = typeConverter->convertType(vectorType32bits);
  Value result = rewriter.createOrFold<LLVM::BitcastOp>(
      loc, llvmVectorType32bits, llvmInput);
  operands.push_back(result);
}
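
// Editor's example: given a signed vector<16xi8> input, this pushes an i1
// sign flag of 1 followed by the input bitcast to vector<4xi32>; an f16
// input (or a bf16 input, after the i16 bitcast above) is pushed as-is with
// no flag.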

/// Push the output operand. For many cases this is only pushing the output in
/// the operand list. But when we have f16 -> f16 or bf16 -> bf16 intrinsics,
/// since the same number of VGPRs is used, we need to decide whether to store
/// the result in the upper or lower 16 bits of the VGPRs. To store the result
/// in the lower 16 bits, set subwordOffset to 1; otherwise the result will be
/// stored in the upper part.
static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
                                  Location loc,
                                  const TypeConverter *typeConverter,
                                  Value output, int32_t subwordOffset,
                                  bool clamp, SmallVector<Value, 4> &operands) {
  Type inputType = output.getType();
  auto vectorType = dyn_cast<VectorType>(inputType);
  Type elemType = vectorType.getElementType();
  if (elemType.isBF16())
    output = rewriter.create<LLVM::BitcastOp>(
        loc, vectorType.clone(rewriter.getI16Type()), output);
  operands.push_back(output);
  if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) {
    operands.push_back(createI1Constant(rewriter, loc, subwordOffset));
  } else if (elemType.isInteger(32)) {
    operands.push_back(createI1Constant(rewriter, loc, clamp));
  }
}

/// Return the `rocdl` intrinsic corresponding to a MFMA operation `mfma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
                                                  Chipset chipset) {
  uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
           b = mfma.getBlocks();
  Type sourceElem = mfma.getSourceA().getType();
  if (auto sourceType = dyn_cast<VectorType>(sourceElem))
    sourceElem = sourceType.getElementType();
  Type destElem = mfma.getDestC().getType();
  if (auto destType = dyn_cast<VectorType>(destElem))
    destElem = destType.getElementType();

  if (sourceElem.isF32() && destElem.isF32()) {
    if (mfma.getReducePrecision() && chipset >= kGfx940) {
      if (m == 32 && n == 32 && k == 4 && b == 1)
        return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
      if (m == 16 && n == 16 && k == 8 && b == 1)
        return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
    }
    if (m == 32 && n == 32 && k == 1 && b == 2)
      return ROCDL::mfma_f32_32x32x1f32::getOperationName();
    if (m == 16 && n == 16 && k == 1 && b == 4)
      return ROCDL::mfma_f32_16x16x1f32::getOperationName();
    if (m == 4 && n == 4 && k == 1 && b == 16)
      return ROCDL::mfma_f32_4x4x1f32::getOperationName();
    if (m == 32 && n == 32 && k == 2 && b == 1)
      return ROCDL::mfma_f32_32x32x2f32::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f32_16x16x4f32::getOperationName();
  }

  if (sourceElem.isF16() && destElem.isF32()) {
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_f32_32x32x4f16::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_f32_16x16x4f16::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_f32_4x4x4f16::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_f32_32x32x8f16::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_f32_16x16x16f16::getOperationName();
  }

  if (sourceElem.isBF16() && destElem.isF32() && chipset >= kGfx90a) {
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
  }

  if (sourceElem.isBF16() && destElem.isF32()) {
    if (m == 32 && n == 32 && k == 2 && b == 2)
      return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
    if (m == 16 && n == 16 && k == 2 && b == 4)
      return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
    if (m == 4 && n == 4 && k == 2 && b == 16)
      return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
    if (m == 32 && n == 32 && k == 4 && b == 1)
      return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
    if (m == 16 && n == 16 && k == 8 && b == 1)
      return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
  }

  if (isa<IntegerType>(sourceElem) && destElem.isInteger(32)) {
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_i32_32x32x4i8::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_i32_16x16x4i8::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_i32_4x4x4i8::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_i32_32x32x8i8::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_i32_16x16x16i8::getOperationName();
    if (m == 32 && n == 32 && k == 16 && b == 1 && chipset >= kGfx940)
      return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
    if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >= kGfx940)
      return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
  }

  if (sourceElem.isF64() && destElem.isF64() && chipset >= kGfx90a) {
    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f64_16x16x4f64::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 4)
      return ROCDL::mfma_f64_4x4x4f64::getOperationName();
  }

  if (isa<Float8E5M2FNUZType>(sourceElem) && destElem.isF32() &&
      chipset >= kGfx940) {
    // Known to be correct because there are no scalar f8 instructions and
    // because a length mismatch will have been caught by the verifier.
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (isa<Float8E5M2FNUZType>(sourceBElem))
        return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
      if (isa<Float8E4M3FNUZType>(sourceBElem))
        return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (isa<Float8E5M2FNUZType>(sourceBElem))
        return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
      if (isa<Float8E4M3FNUZType>(sourceBElem))
        return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
    }
  }

  if (isa<Float8E4M3FNUZType>(sourceElem) && destElem.isF32() &&
      chipset >= kGfx940) {
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (isa<Float8E5M2FNUZType>(sourceBElem))
        return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
      if (isa<Float8E4M3FNUZType>(sourceBElem))
        return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (isa<Float8E5M2FNUZType>(sourceBElem))
        return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
      if (isa<Float8E4M3FNUZType>(sourceBElem))
        return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
    }
  }

  return std::nullopt;
}
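
// Editor's example of reading the table above: an amdgpu.mfma with
// m = 32, n = 32, k = 8, blocks = 1, f16 sources, and an f32 accumulator
// maps to rocdl.mfma.f32.32x32x8f16; shapes with no matching row (for the
// given chipset) return std::nullopt and the caller emits
// "no intrinsic matching MFMA size on given chipset".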

/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
                                                  Chipset chipset) {
  auto sourceVectorType = dyn_cast<VectorType>(wmma.getSourceA().getType());
  auto destVectorType = dyn_cast<VectorType>(wmma.getDestC().getType());
  auto elemSourceType = sourceVectorType.getElementType();
  auto elemDestType = destVectorType.getElementType();

  if (elemSourceType.isF16() && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
  if (elemSourceType.isBF16() && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
  if (elemSourceType.isF16() && elemDestType.isF16())
    return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
  if (elemSourceType.isBF16() && elemDestType.isBF16())
    return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
  if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
    return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
  if (isa<Float8E4M3FNType>(elemSourceType) && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_fp8::getOperationName();
  if (isa<Float8E5M2Type>(elemSourceType) && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_bf8::getOperationName();
  return std::nullopt;
}

namespace {
struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
  MFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type outType = typeConverter->convertType(op.getDestD().getType());
    Type intrinsicOutType = outType;
    if (auto outVecType = dyn_cast<VectorType>(outType))
      if (outVecType.getElementType().isBF16())
        intrinsicOutType = outVecType.clone(rewriter.getI16Type());

    if (chipset.majorVersion != 9 || chipset < kGfx908)
      return op->emitOpError("MFMA only supported on gfx908+");
    uint32_t getBlgpField = static_cast<uint32_t>(op.getBlgp());
    if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
      if (chipset < kGfx940)
        return op.emitOpError("negation unsupported on older than gfx940");
      getBlgpField |=
          op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
    }
    std::optional<StringRef> maybeIntrinsic = mfmaOpToIntrinsic(op, chipset);
    if (!maybeIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching MFMA size on given chipset");
    OperationState loweredOp(loc, *maybeIntrinsic);
    loweredOp.addTypes(intrinsicOutType);
    loweredOp.addOperands(
        {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
         convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
         adaptor.getDestC(), createI32Constant(rewriter, loc, op.getCbsz()),
         createI32Constant(rewriter, loc, op.getAbid()),
         createI32Constant(rewriter, loc, getBlgpField)});
    Value lowered = rewriter.create(loweredOp)->getResult(0);
    if (outType != intrinsicOutType)
      lowered = rewriter.create<LLVM::BitcastOp>(loc, outType, lowered);
    rewriter.replaceOp(op, lowered);
    return success();
  }
};
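
// Rough end-to-end sketch of the MFMA lowering (editor's addition; the MLIR
// below is illustrative, not verified output):
//
//   %d = amdgpu.mfma %a * %b + %c { m = 16, n = 16, k = 16, blocks = 1, ... }
//       : vector<4xf16>, vector<4xf16>, vector<4xf32>
//
// becomes approximately
//
//   %d = rocdl.mfma.f32.16x16x16f16 %a, %b, %c, %cbsz, %abid, %blgp
//       : (vector<4xf16>, vector<4xf16>, vector<4xf32>, i32, i32, i32)
//         -> vector<4xf32>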

struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
  WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto outType =
        typeConverter->convertType<VectorType>(op.getDestD().getType());
    if (!outType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
      return op->emitOpError("WMMA only supported on gfx11 and gfx12");

    // The WMMA operations represent vectors of bf16s as vectors of i16s, so we
    // need to bitcast bfloats to i16 and then bitcast them back.
    VectorType rawOutType = outType;
    if (outType.getElementType().isBF16())
      rawOutType = outType.clone(rewriter.getI16Type());

    std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);

    if (!maybeIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching WMMA on the given chipset");

    OperationState loweredOp(loc, *maybeIntrinsic);
    loweredOp.addTypes(rawOutType);

    SmallVector<Value, 4> operands;
    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(),
                         adaptor.getSourceA(), op.getSourceA(), operands);
    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(),
                         adaptor.getSourceB(), op.getSourceB(), operands);
    wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(),
                          op.getSubwordOffset(), op.getClamp(), operands);

    loweredOp.addOperands(operands);
    Operation *lowered = rewriter.create(loweredOp);

    Operation *maybeCastBack = lowered;
    if (rawOutType != outType)
      maybeCastBack =
          rewriter.create<LLVM::BitcastOp>(loc, outType, lowered->getResult(0));
    rewriter.replaceOp(op, maybeCastBack->getResults());

    return success();
  }
};
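
// Editor's note: the WMMA lowering is analogous to the MFMA sketch above; for
// example, vector<16xf16> inputs with a vector<8xf32> accumulator would go to
// rocdl.wmma.f32.16x16x16.f16 (shapes assumed for illustration), with bf16
// results round-tripped through i16 vectors by the bitcasts above.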

namespace {
struct ExtPackedFp8OpLowering final
    : public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
  ExtPackedFp8OpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::ExtPackedFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

struct PackedTrunc2xFp8OpLowering final
    : public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> {
  PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter,
                             Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::PackedTrunc2xFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

struct PackedStochRoundFp8OpLowering final
    : public ConvertOpToLLVMPattern<PackedStochRoundFp8Op> {
  PackedStochRoundFp8OpLowering(const LLVMTypeConverter &converter,
                                Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::PackedStochRoundFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PackedStochRoundFp8Op op,
                  PackedStochRoundFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // end namespace

LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
    ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (chipset.majorVersion != 9 || chipset < kGfx940)
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type v4i8 =
      getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type()));
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
  Type f32 = getTypeConverter()->convertType(op.getResult().getType());

  Value source = adaptor.getSource();
  auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
  Type sourceElemType = getElementTypeOrSelf(op.getSource());
  // Extend to a v4i8.
  if (!sourceVecType || sourceVecType.getNumElements() < 4) {
    Value longVec = rewriter.create<LLVM::UndefOp>(loc, v4i8);
    if (!sourceVecType) {
      longVec = rewriter.create<LLVM::InsertElementOp>(
          loc, longVec, source, createI32Constant(rewriter, loc, 0));
    } else {
      for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
        Value idx = createI32Constant(rewriter, loc, i);
        Value elem = rewriter.create<LLVM::ExtractElementOp>(loc, source, idx);
        longVec =
            rewriter.create<LLVM::InsertElementOp>(loc, longVec, elem, idx);
      }
    }
    source = longVec;
  }
  Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
  Value wordSel = createI32Constant(rewriter, loc, op.getIndex());
  if (isa<Float8E5M2FNUZType>(sourceElemType)) {
    rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
                                                    wordSel);
  } else if (isa<Float8E4M3FNUZType>(sourceElemType)) {
    rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
                                                    wordSel);
  }
  return success();
}
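
// Editor's sketch of the resulting IR (illustrative only): extracting byte 1
// of a vector<4xf8E4M3FNUZ> %v as f32 becomes roughly
//
//   %w = llvm.bitcast %v : vector<4xi8> to i32
//   %f = rocdl.cvt.f32.fp8 %w[1] : f32
//
// where the trailing i32 operand selects which of the four fp8 bytes to
// convert.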

LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
    PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (chipset.majorVersion != 9 || chipset < kGfx940)
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Type resultType = op.getResult().getType();
  Type resultElemType = getElementTypeOrSelf(resultType);

  Value sourceA = adaptor.getSourceA();
  Value sourceB = adaptor.getSourceB();
  if (!sourceB)
    sourceB = rewriter.create<LLVM::UndefOp>(loc, sourceA.getType());
  Value existing = adaptor.getExisting();
  if (existing)
    existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
  else
    existing = rewriter.create<LLVM::UndefOp>(loc, i32);
  Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex());

  Value result;
  if (isa<Float8E5M2FNUZType>(resultElemType))
    result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
                                                   existing, wordSel);
  else if (isa<Float8E4M3FNUZType>(resultElemType))
    result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
                                                   existing, wordSel);

  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
  return success();
}

LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
    PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (chipset.majorVersion != 9 || chipset < kGfx940)
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Type resultType = op.getResult().getType();
  Type resultElemType = getElementTypeOrSelf(resultType);

  Value source = adaptor.getSource();
  Value stoch = adaptor.getStochiasticParam();
  Value existing = adaptor.getExisting();
  if (existing)
    existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
  else
    existing = rewriter.create<LLVM::UndefOp>(loc, i32);
  Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex());

  Value result;
  if (isa<Float8E5M2FNUZType>(resultElemType))
    result = rewriter.create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch,
                                                   existing, byteSel);
  else if (isa<Float8E4M3FNUZType>(resultElemType))
    result = rewriter.create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch,
                                                   existing, byteSel);

  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
  return success();
}

// Implement the AMDGPUDPPLowering class that converts the amdgpu.dpp
// operation into the corresponding ROCDL instructions.
struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
  AMDGPUDPPLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    // Convert the source operand to the corresponding LLVM type.
    Location loc = DppOp.getLoc();
    Value src = adaptor.getSrc();
    Value old = adaptor.getOld();
    Type srcType = src.getType();
    Type oldType = old.getType();
    Type llvmType = nullptr;
    if (srcType.getIntOrFloatBitWidth() < 32) {
      llvmType = rewriter.getI32Type();
    } else if (isa<FloatType>(srcType)) {
      llvmType = (srcType.getIntOrFloatBitWidth() == 32)
                     ? rewriter.getF32Type()
                     : rewriter.getF64Type();
    } else if (isa<IntegerType>(srcType)) {
      llvmType = (srcType.getIntOrFloatBitWidth() == 32)
                     ? rewriter.getI32Type()
                     : rewriter.getI64Type();
    }
    auto llvmSrcIntType = typeConverter->convertType(
        rewriter.getIntegerType(srcType.getIntOrFloatBitWidth()));

    // If the source type is narrower than 32 bits, widen it to an i32 by
    // bitcasting and inserting it into a wider vector.
    auto convertOperand = [&](Value operand, Type operandType) {
      if (operandType.getIntOrFloatBitWidth() <= 16) {
        if (llvm::isa<FloatType>(operandType)) {
          operand =
              rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
        }
        auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
            32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
        Value undefVec = rewriter.create<LLVM::UndefOp>(loc, llvmVecType);
        operand = rewriter.create<LLVM::InsertElementOp>(
            loc, undefVec, operand, createI32Constant(rewriter, loc, 0));
        operand = rewriter.create<LLVM::BitcastOp>(loc, llvmType, operand);
      }
      return operand;
    };

    src = convertOperand(src, srcType);
    old = convertOperand(old, oldType);

    // These control values are taken from llvm/lib/Target/AMDGPU/SIDefines.h.
    enum DppCtrl : unsigned {
      ROW_SHL0 = 0x100,
      ROW_SHR0 = 0x110,
      ROW_ROR0 = 0x120,
      WAVE_SHL1 = 0x130,
      WAVE_ROL1 = 0x134,
      WAVE_SHR1 = 0x138,
      WAVE_ROR1 = 0x13C,
      ROW_MIRROR = 0x140,
      ROW_HALF_MIRROR = 0x141,
      BCAST15 = 0x142,
      BCAST31 = 0x143,
    };

    auto kind = DppOp.getKind();
    auto permArgument = DppOp.getPermArgument();
    uint32_t DppCtrl = 0;

    switch (kind) {

    case DPPPerm::quad_perm:
      if (auto quadPermAttr = cast<ArrayAttr>(*permArgument)) {
        int32_t i = 0;
        for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
          uint32_t num = elem.getInt();
          DppCtrl |= num << (i * 2);
          i++;
        }
      }
      break;
    case DPPPerm::row_shl:
      if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
        DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
      }
      break;
    case DPPPerm::row_shr:
      if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
        DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
      }
      break;
    case DPPPerm::row_ror:
      if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
        DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
      }
      break;
    case DPPPerm::wave_shl:
      DppCtrl = DppCtrl::WAVE_SHL1;
      break;
    case DPPPerm::wave_shr:
      DppCtrl = DppCtrl::WAVE_SHR1;
      break;
    case DPPPerm::wave_rol:
      DppCtrl = DppCtrl::WAVE_ROL1;
      break;
    case DPPPerm::wave_ror:
      DppCtrl = DppCtrl::WAVE_ROR1;
      break;
    case DPPPerm::row_mirror:
      DppCtrl = DppCtrl::ROW_MIRROR;
      break;
    case DPPPerm::row_half_mirror:
      DppCtrl = DppCtrl::ROW_HALF_MIRROR;
      break;
    case DPPPerm::row_bcast_15:
      DppCtrl = DppCtrl::BCAST15;
      break;
    case DPPPerm::row_bcast_31:
      DppCtrl = DppCtrl::BCAST31;
      break;
    }
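
    // Editor's example: quad_perm [1, 0, 3, 2] packs into
    // DppCtrl = 1 << 0 | 0 << 2 | 3 << 4 | 2 << 6 = 0xB1, i.e. a pairwise
    // swap within each group of four lanes.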

    // Read the row_mask, bank_mask, and bound_ctrl attributes.
    auto rowMask = DppOp->getAttrOfType<IntegerAttr>("row_mask").getInt();
    auto bankMask = DppOp->getAttrOfType<IntegerAttr>("bank_mask").getInt();
    bool boundCtrl = DppOp->getAttrOfType<BoolAttr>("bound_ctrl").getValue();

    // Create a ROCDL::DPPUpdateOp instruction with the appropriate attributes.
    auto dppMovOp = rewriter.create<ROCDL::DPPUpdateOp>(
        loc, llvmType, old, src, DppCtrl, rowMask, bankMask, boundCtrl);

    Value result = dppMovOp.getRes();
    if (srcType.getIntOrFloatBitWidth() < 32) {
      result = rewriter.create<LLVM::TruncOp>(loc, llvmSrcIntType, result);
      if (!llvm::isa<IntegerType>(srcType)) {
        result = rewriter.create<LLVM::BitcastOp>(loc, srcType, result);
      }
    }

    // Replace the amdgpu.dpp operation with the lowered ROCDL::DPPUpdateOp.
    rewriter.replaceOp(DppOp, ValueRange(result));
    return success();
  }
};

struct ConvertAMDGPUToROCDLPass
    : public impl::ConvertAMDGPUToROCDLBase<ConvertAMDGPUToROCDLPass> {
  ConvertAMDGPUToROCDLPass() = default;

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    FailureOr<Chipset> maybeChipset = Chipset::parse(chipset);
    if (failed(maybeChipset)) {
      emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
      return signalPassFailure();
    }

    RewritePatternSet patterns(ctx);
    LLVMTypeConverter converter(ctx);
    populateAMDGPUToROCDLConversionPatterns(converter, patterns, *maybeChipset);
    LLVMConversionTarget target(getContext());
    target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
    target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
    target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

void mlir::populateAMDGPUToROCDLConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
    Chipset chipset) {
  patterns
      .add<RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
           RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
           RawBufferOpLowering<RawBufferAtomicFaddOp,
                               ROCDL::RawPtrBufferAtomicFaddOp>,
           RawBufferOpLowering<RawBufferAtomicFmaxOp,
                               ROCDL::RawPtrBufferAtomicFmaxOp>,
           RawBufferOpLowering<RawBufferAtomicSmaxOp,
                               ROCDL::RawPtrBufferAtomicSmaxOp>,
           RawBufferOpLowering<RawBufferAtomicUminOp,
                               ROCDL::RawPtrBufferAtomicUminOp>,
           RawBufferOpLowering<RawBufferAtomicCmpswapOp,
                               ROCDL::RawPtrBufferAtomicCmpSwap>,
           AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
           MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
           PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering>(converter,
                                                                      chipset);
}

std::unique_ptr<Pass> mlir::createConvertAMDGPUToROCDLPass() {
  return std::make_unique<ConvertAMDGPUToROCDLPass>();
}
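
// Editor's usage note: with this pass registered, the conversion can be run
// from the command line as, e.g.,
//
//   mlir-opt --convert-amdgpu-to-rocdl=chipset=gfx90a input.mlir
//
// (pass name and option spelling assumed from the pass definition, not
// verified here).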