//===- AMDGPUToROCDL.cpp - AMDGPU to ROCDL dialect conversion -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 
#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"

#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"

#include "llvm/ADT/STLExtras.h"
#include <optional>
24 
25 namespace mlir {
26 #define GEN_PASS_DEF_CONVERTAMDGPUTOROCDL
27 #include "mlir/Conversion/Passes.h.inc"
28 } // namespace mlir
29 
30 using namespace mlir;
31 using namespace mlir::amdgpu;
32 
34  Location loc, int32_t value) {
35  Type llvmI32 = rewriter.getI32Type();
36  return rewriter.create<LLVM::ConstantOp>(loc, llvmI32, value);
37 }
38 
40  bool value) {
41  Type llvmI1 = rewriter.getI1Type();
42  return rewriter.create<LLVM::ConstantOp>(loc, llvmI1, value);
43 }
44 
45 namespace {
46 // Define commonly used chipsets versions for convenience.
47 constexpr Chipset kGfx908 = Chipset(9, 0, 8);
48 constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
49 constexpr Chipset kGfx940 = Chipset(9, 4, 0);
50 
51 /// Define lowering patterns for raw buffer ops
52 template <typename GpuOp, typename Intrinsic>
53 struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
54  RawBufferOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
55  : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}
56 
57  Chipset chipset;
58  static constexpr uint32_t maxVectorOpWidth = 128;
59 
60  LogicalResult
61  matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor,
62  ConversionPatternRewriter &rewriter) const override {
63  Location loc = gpuOp.getLoc();
64  Value memref = adaptor.getMemref();
65  Value unconvertedMemref = gpuOp.getMemref();
66  MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());
67 
68  if (chipset.majorVersion < 9)
69  return gpuOp.emitOpError("raw buffer ops require GCN or higher");
70 
71  Value storeData = adaptor.getODSOperands(0)[0];
72  if (storeData == memref) // no write component to this op
73  storeData = Value();
74  Type wantedDataType;
75  if (storeData)
76  wantedDataType = storeData.getType();
77  else
78  wantedDataType = gpuOp.getODSResults(0)[0].getType();
79 
80  Value atomicCmpData = Value();
81  // Operand index 1 of a load is the indices, trying to read them can crash.
82  if (storeData) {
83  Value maybeCmpData = adaptor.getODSOperands(1)[0];
84  if (maybeCmpData != memref)
85  atomicCmpData = maybeCmpData;
86  }
87 
88  Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
89 
90  Type i32 = rewriter.getI32Type();
91  Type llvmI32 = this->typeConverter->convertType(i32);
92  Type llvmI16 = this->typeConverter->convertType(rewriter.getI16Type());
93 
94  int64_t elementByteWidth = memrefType.getElementTypeBitWidth() / 8;
95  Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);
96 
97  // If we want to load a vector<NxT> with total size <= 32
98  // bits, use a scalar load and bitcast it. Similarly, if bitsize(T) < 32
99  // and the total load size is >= 32, use a vector load of N / (bitsize(T) /
100  // 32) x i32 and bitcast. Also, the CAS intrinsic requires integer operands,
101  // so bitcast any floats to integers.
102  Type llvmBufferValType = llvmWantedDataType;
103  if (atomicCmpData) {
104  if (auto floatType = dyn_cast<FloatType>(wantedDataType))
105  llvmBufferValType = this->getTypeConverter()->convertType(
106  rewriter.getIntegerType(floatType.getWidth()));
107  }
108  if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
109  uint32_t vecLen = dataVector.getNumElements();
110  uint32_t elemBits = dataVector.getElementTypeBitWidth();
111  uint32_t totalBits = elemBits * vecLen;
112  bool usePackedFp16 =
113  isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
114  if (totalBits > maxVectorOpWidth)
115  return gpuOp.emitOpError(
116  "Total width of loads or stores must be no more than " +
117  Twine(maxVectorOpWidth) + " bits, but we call for " +
118  Twine(totalBits) +
119  " bits. This should've been caught in validation");
120  if (!usePackedFp16 && elemBits < 32) {
121  if (totalBits > 32) {
122  if (totalBits % 32 != 0)
123  return gpuOp.emitOpError("Load or store of more than 32-bits that "
124  "doesn't fit into words. Can't happen\n");
125  llvmBufferValType = this->typeConverter->convertType(
126  VectorType::get(totalBits / 32, i32));
127  } else {
128  llvmBufferValType = this->typeConverter->convertType(
129  rewriter.getIntegerType(totalBits));
130  }
131  }
132  }
133 
135  if (storeData) {
136  if (llvmBufferValType != llvmWantedDataType) {
137  Value castForStore =
138  rewriter.create<LLVM::BitcastOp>(loc, llvmBufferValType, storeData);
139  args.push_back(castForStore);
140  } else {
141  args.push_back(storeData);
142  }
143  }
144 
145  if (atomicCmpData) {
146  if (llvmBufferValType != llvmWantedDataType) {
147  Value castForCmp = rewriter.create<LLVM::BitcastOp>(
148  loc, llvmBufferValType, atomicCmpData);
149  args.push_back(castForCmp);
150  } else {
151  args.push_back(atomicCmpData);
152  }
153  }
154 
155  // Construct buffer descriptor from memref, attributes
156  int64_t offset = 0;
157  SmallVector<int64_t, 5> strides;
158  if (failed(getStridesAndOffset(memrefType, strides, offset)))
159  return gpuOp.emitOpError("Can't lower non-stride-offset memrefs");
160 
161  MemRefDescriptor memrefDescriptor(memref);
162 
163  Value ptr = memrefDescriptor.alignedPtr(rewriter, loc);
164  // The stride value is always 0 for raw buffers. This also disables
165  // swizling.
166  Value stride = rewriter.create<LLVM::ConstantOp>(
167  loc, llvmI16, rewriter.getI16IntegerAttr(0));
168  Value numRecords;
169  if (memrefType.hasStaticShape()) {
170  numRecords = createI32Constant(
171  rewriter, loc,
172  static_cast<int32_t>(memrefType.getNumElements() * elementByteWidth));
173  } else {
174  Value maxIndex;
175  for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
176  Value size = memrefDescriptor.size(rewriter, loc, i);
177  Value stride = memrefDescriptor.stride(rewriter, loc, i);
178  stride = rewriter.create<LLVM::MulOp>(loc, stride, byteWidthConst);
179  Value maxThisDim = rewriter.create<LLVM::MulOp>(loc, size, stride);
180  maxIndex = maxIndex ? rewriter.create<LLVM::MaximumOp>(loc, maxIndex,
181  maxThisDim)
182  : maxThisDim;
183  }
184  numRecords = rewriter.create<LLVM::TruncOp>(loc, llvmI32, maxIndex);
185  }
186 
187  // Flag word:
188  // bits 0-11: dst sel, ignored by these intrinsics
189  // bits 12-14: data format (ignored, must be nonzero, 7=float)
190  // bits 15-18: data format (ignored, must be nonzero, 4=32bit)
191  // bit 19: In nested heap (0 here)
192  // bit 20: Behavior on unmap (0 means "return 0 / ignore")
193  // bits 21-22: Index stride for swizzles (N/A)
194  // bit 23: Add thread ID (0)
195  // bit 24: Reserved to 1 (RDNA) or 0 (CDNA)
196  // bits 25-26: Reserved (0)
197  // bit 27: Buffer is non-volatile (CDNA only)
198  // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 =
199  // none, 3 = either swizzles or testing against offset field) RDNA only
200  // bits 30-31: Type (must be 0)
201  uint32_t flags = (7 << 12) | (4 << 15);
202  if (chipset.majorVersion >= 10) {
203  flags |= (1 << 24);
204  uint32_t oob = adaptor.getBoundsCheck() ? 3 : 2;
205  flags |= (oob << 28);
206  }
207  Value flagsConst = createI32Constant(rewriter, loc, flags);
208  Type rsrcType = LLVM::LLVMPointerType::get(rewriter.getContext(), 8);
209  Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
210  loc, rsrcType, ptr, stride, numRecords, flagsConst);
211  args.push_back(resource);
212 
213  // Indexing (voffset)
214  Value voffset = createI32Constant(rewriter, loc, 0);
215  for (auto pair : llvm::enumerate(adaptor.getIndices())) {
216  size_t i = pair.index();
217  Value index = pair.value();
218  Value strideOp;
219  if (ShapedType::isDynamic(strides[i])) {
220  strideOp = rewriter.create<LLVM::MulOp>(
221  loc, memrefDescriptor.stride(rewriter, loc, i), byteWidthConst);
222  } else {
223  strideOp =
224  createI32Constant(rewriter, loc, strides[i] * elementByteWidth);
225  }
226  index = rewriter.create<LLVM::MulOp>(loc, index, strideOp);
227  voffset = rewriter.create<LLVM::AddOp>(loc, voffset, index);
228  }
229  if (adaptor.getIndexOffset()) {
230  int32_t indexOffset = *gpuOp.getIndexOffset() * elementByteWidth;
231  Value extraOffsetConst = createI32Constant(rewriter, loc, indexOffset);
232  voffset =
233  voffset ? rewriter.create<LLVM::AddOp>(loc, voffset, extraOffsetConst)
234  : extraOffsetConst;
235  }
236  args.push_back(voffset);
237 
238  Value sgprOffset = adaptor.getSgprOffset();
239  if (!sgprOffset)
240  sgprOffset = createI32Constant(rewriter, loc, 0);
241  if (ShapedType::isDynamic(offset))
242  sgprOffset = rewriter.create<LLVM::AddOp>(
243  loc, memrefDescriptor.offset(rewriter, loc), sgprOffset);
244  else if (offset > 0)
245  sgprOffset = rewriter.create<LLVM::AddOp>(
246  loc, sgprOffset, createI32Constant(rewriter, loc, offset));
247  args.push_back(sgprOffset);
248 
249  // bit 0: GLC = 0 (atomics drop value, less coherency)
250  // bits 1-2: SLC, DLC = 0 (similarly)
251  // bit 3: swizzled (0 for raw)
252  args.push_back(createI32Constant(rewriter, loc, 0));
253 
254  llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
255  llvmBufferValType);
256  Operation *lowered = rewriter.create<Intrinsic>(loc, resultTypes, args,
258  if (lowered->getNumResults() == 1) {
259  Value replacement = lowered->getResult(0);
260  if (llvmBufferValType != llvmWantedDataType) {
261  replacement = rewriter.create<LLVM::BitcastOp>(loc, llvmWantedDataType,
262  replacement);
263  }
264  rewriter.replaceOp(gpuOp, replacement);
265  } else {
266  rewriter.eraseOp(gpuOp);
267  }
268  return success();
269  }
270 };
271 
272 struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
273  LDSBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
274  : ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {}
275 
276  Chipset chipset;
277 
278  LogicalResult
279  matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
280  ConversionPatternRewriter &rewriter) const override {
281  bool requiresInlineAsm = chipset < kGfx90a || chipset.majorVersion == 11;
282 
283  if (requiresInlineAsm) {
284  auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
285  LLVM::AsmDialect::AD_ATT);
286  const char *asmStr =
287  ";;;WARNING: BREAKS DEBUG WATCHES\ns_waitcnt lgkmcnt(0)\ns_barrier";
288  const char *constraints = "";
289  rewriter.replaceOpWithNewOp<LLVM::InlineAsmOp>(
290  op,
291  /*resultTypes=*/TypeRange(), /*operands=*/ValueRange(),
292  /*asm_string=*/asmStr, constraints, /*has_side_effects=*/true,
293  /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
294  /*operand_attrs=*/ArrayAttr());
295  return success();
296  }
297  if (chipset.majorVersion < 12) {
298  constexpr int32_t ldsOnlyBitsGfx6789 = ~(0x1f << 8);
299  constexpr int32_t ldsOnlyBitsGfx10 = ~(0x3f << 8);
300  // Left in place in case someone disables the inline ASM path or future
301  // chipsets use the same bit pattern.
302  constexpr int32_t ldsOnlyBitsGfx11 = ~(0x3f << 4);
303 
304  int32_t ldsOnlyBits;
305  if (chipset.majorVersion == 11)
306  ldsOnlyBits = ldsOnlyBitsGfx11;
307  else if (chipset.majorVersion == 10)
308  ldsOnlyBits = ldsOnlyBitsGfx10;
309  else if (chipset.majorVersion <= 9)
310  ldsOnlyBits = ldsOnlyBitsGfx6789;
311  else
312  return op.emitOpError(
313  "don't know how to lower this for chipset major version")
314  << chipset.majorVersion;
315 
316  Location loc = op->getLoc();
317  rewriter.create<ROCDL::WaitcntOp>(loc, ldsOnlyBits);
318  rewriter.replaceOpWithNewOp<ROCDL::SBarrierOp>(op);
319  } else {
320  Location loc = op->getLoc();
321  rewriter.create<ROCDL::WaitDscntOp>(loc, 0);
322  rewriter.create<ROCDL::BarrierSignalOp>(loc, -1);
323  rewriter.replaceOpWithNewOp<ROCDL::BarrierWaitOp>(op, -1);
324  }
325 
326  return success();
327  }
328 };
329 
330 struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
331  SchedBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
332  : ConvertOpToLLVMPattern<SchedBarrierOp>(converter), chipset(chipset) {}
333 
334  Chipset chipset;
335 
336  LogicalResult
337  matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor,
338  ConversionPatternRewriter &rewriter) const override {
339  rewriter.replaceOpWithNewOp<ROCDL::SchedBarrier>(op,
340  (uint32_t)op.getOpts());
341  return success();
342  }
343 };
344 
345 } // namespace
346 
347 /// Converts a MFMA vector operand from MLIR AMDGPU dialect convention to ROCDL
348 /// and LLVM AMDGPU intrinsics convention.
349 ///
350 /// Specifically:
351 /// 1. If `input` is a vector of N bytes, bitcast it to a (N * 8)-bit integer.
352 /// 2. If the element type is bfloat16, bitcast it to i16.
354  Location loc, Value input) {
355  Type inputType = input.getType();
356  if (auto vectorType = dyn_cast<VectorType>(inputType)) {
357  if (vectorType.getElementType().isBF16())
358  return rewriter.create<LLVM::BitcastOp>(
359  loc, vectorType.clone(rewriter.getI16Type()), input);
360  if (vectorType.getElementType().isInteger(8)) {
361  return rewriter.create<LLVM::BitcastOp>(
362  loc, rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
363  }
364  }
365  return input;
366 }
367 
368 /// Push an input operand. If it is a float type, nothing to do. If it is
369 /// an integer type, then we need to also push its signdness (1 for signed, 0
370 /// for unsigned) and we need to pack the input 16xi8 vector into a 4xi32
371 /// vector. We also need to convert bfloat inputs to i16 to account for the lack
372 /// of bfloat support in the WMMA intrinsics themselves.
374  Location loc,
375  const TypeConverter *typeConverter,
376  bool isUnsigned, Value llvmInput,
377  Value mlirInput,
378  SmallVector<Value, 4> &operands) {
379  Type inputType = llvmInput.getType();
380  auto vectorType = dyn_cast<VectorType>(inputType);
381  Type elemType = vectorType.getElementType();
382 
383  if (elemType.isBF16())
384  llvmInput = rewriter.create<LLVM::BitcastOp>(
385  loc, vectorType.clone(rewriter.getI16Type()), llvmInput);
386  if (!elemType.isInteger(8)) {
387  operands.push_back(llvmInput);
388  return;
389  }
390 
391  // We need to check the type of the input before conversion to properly test
392  // for int8. This is because, in LLVM, fp8 type is converted to int8, so the
393  // fp8/int8 information is lost during the conversion process.
394  auto mlirInputType = cast<VectorType>(mlirInput.getType());
395  bool isInputInt8 = mlirInputType.getElementType().isInteger(8);
396  if (isInputInt8) {
397  // if element type is 8-bit signed or unsigned, ignore the isUnsigned flag
398  bool localIsUnsigned = isUnsigned;
399  if (elemType.isUnsignedInteger(8)) {
400  localIsUnsigned = true;
401  } else if (elemType.isSignedInteger(8)) {
402  localIsUnsigned = false;
403  }
404  Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
405  operands.push_back(sign);
406  }
407 
408  int64_t numBytes = vectorType.getNumElements();
409  Type i32 = rewriter.getI32Type();
410  VectorType vectorType32bits = VectorType::get(numBytes * 8 / 32, i32);
411  auto llvmVectorType32bits = typeConverter->convertType(vectorType32bits);
412  Value result = rewriter.createOrFold<LLVM::BitcastOp>(
413  loc, llvmVectorType32bits, llvmInput);
414  operands.push_back(result);
415 }
416 
417 /// Push the output operand. For many cases this is only pushing the output in
418 /// the operand list. But when we have f16 -> f16 or bf16 -> bf16 intrinsics,
419 /// since the same numbers of VGPRs is used, we need to decide if to store the
420 /// result in the upper 16 bits of the VGPRs or in the lower part. To store the
421 /// result in the lower 16 bits, set subwordOffset to 1, otherwise result will
422 /// be stored it in the upper part
424  Location loc,
425  const TypeConverter *typeConverter,
426  Value output, int32_t subwordOffset,
427  bool clamp, SmallVector<Value, 4> &operands) {
428  Type inputType = output.getType();
429  auto vectorType = dyn_cast<VectorType>(inputType);
430  Type elemType = vectorType.getElementType();
431  if (elemType.isBF16())
432  output = rewriter.create<LLVM::BitcastOp>(
433  loc, vectorType.clone(rewriter.getI16Type()), output);
434  operands.push_back(output);
435  if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) {
436  operands.push_back(createI1Constant(rewriter, loc, subwordOffset));
437  } else if (elemType.isInteger(32)) {
438  operands.push_back(createI1Constant(rewriter, loc, clamp));
439  }
440 }
441 
442 /// Return the `rocdl` intrinsic corresponding to a MFMA operation `mfma`
443 /// if one exists. This includes checking to ensure the intrinsic is supported
444 /// on the architecture you are compiling for.
445 static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
446  Chipset chipset) {
447  uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
448  b = mfma.getBlocks();
449  Type sourceElem = mfma.getSourceA().getType();
450  if (auto sourceType = dyn_cast<VectorType>(sourceElem))
451  sourceElem = sourceType.getElementType();
452  Type destElem = mfma.getDestC().getType();
453  if (auto destType = dyn_cast<VectorType>(destElem))
454  destElem = destType.getElementType();
455 
456  if (sourceElem.isF32() && destElem.isF32()) {
457  if (mfma.getReducePrecision() && chipset >= kGfx940) {
458  if (m == 32 && n == 32 && k == 4 && b == 1)
459  return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
460  if (m == 16 && n == 16 && k == 8 && b == 1)
461  return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
462  }
463  if (m == 32 && n == 32 && k == 1 && b == 2)
464  return ROCDL::mfma_f32_32x32x1f32::getOperationName();
465  if (m == 16 && n == 16 && k == 1 && b == 4)
466  return ROCDL::mfma_f32_16x16x1f32::getOperationName();
467  if (m == 4 && n == 4 && k == 1 && b == 16)
468  return ROCDL::mfma_f32_4x4x1f32::getOperationName();
469  if (m == 32 && n == 32 && k == 2 && b == 1)
470  return ROCDL::mfma_f32_32x32x2f32::getOperationName();
471  if (m == 16 && n == 16 && k == 4 && b == 1)
472  return ROCDL::mfma_f32_16x16x4f32::getOperationName();
473  }
474 
475  if (sourceElem.isF16() && destElem.isF32()) {
476  if (m == 32 && n == 32 && k == 4 && b == 2)
477  return ROCDL::mfma_f32_32x32x4f16::getOperationName();
478  if (m == 16 && n == 16 && k == 4 && b == 4)
479  return ROCDL::mfma_f32_16x16x4f16::getOperationName();
480  if (m == 4 && n == 4 && k == 4 && b == 16)
481  return ROCDL::mfma_f32_4x4x4f16::getOperationName();
482  if (m == 32 && n == 32 && k == 8 && b == 1)
483  return ROCDL::mfma_f32_32x32x8f16::getOperationName();
484  if (m == 16 && n == 16 && k == 16 && b == 1)
485  return ROCDL::mfma_f32_16x16x16f16::getOperationName();
486  }
487 
488  if (sourceElem.isBF16() && destElem.isF32() && chipset >= kGfx90a) {
489  if (m == 32 && n == 32 && k == 4 && b == 2)
490  return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
491  if (m == 16 && n == 16 && k == 4 && b == 4)
492  return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
493  if (m == 4 && n == 4 && k == 4 && b == 16)
494  return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
495  if (m == 32 && n == 32 && k == 8 && b == 1)
496  return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
497  if (m == 16 && n == 16 && k == 16 && b == 1)
498  return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
499  }
500 
501  if (sourceElem.isBF16() && destElem.isF32()) {
502  if (m == 32 && n == 32 && k == 2 && b == 2)
503  return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
504  if (m == 16 && n == 16 && k == 2 && b == 4)
505  return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
506  if (m == 4 && n == 4 && k == 2 && b == 16)
507  return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
508  if (m == 32 && n == 32 && k == 4 && b == 1)
509  return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
510  if (m == 16 && n == 16 && k == 8 && b == 1)
511  return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
512  }
513 
514  if (isa<IntegerType>(sourceElem) && destElem.isInteger(32)) {
515  if (m == 32 && n == 32 && k == 4 && b == 2)
516  return ROCDL::mfma_i32_32x32x4i8::getOperationName();
517  if (m == 16 && n == 16 && k == 4 && b == 4)
518  return ROCDL::mfma_i32_16x16x4i8::getOperationName();
519  if (m == 4 && n == 4 && k == 4 && b == 16)
520  return ROCDL::mfma_i32_4x4x4i8::getOperationName();
521  if (m == 32 && n == 32 && k == 8 && b == 1)
522  return ROCDL::mfma_i32_32x32x8i8::getOperationName();
523  if (m == 16 && n == 16 && k == 16 && b == 1)
524  return ROCDL::mfma_i32_16x16x16i8::getOperationName();
525  if (m == 32 && n == 32 && k == 16 && b == 1 && chipset >= kGfx940)
526  return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
527  if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >= kGfx940)
528  return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
529  }
530 
531  if (sourceElem.isF64() && destElem.isF64() && chipset >= kGfx90a) {
532  if (m == 16 && n == 16 && k == 4 && b == 1)
533  return ROCDL::mfma_f64_16x16x4f64::getOperationName();
534  if (m == 4 && n == 4 && k == 4 && b == 4)
535  return ROCDL::mfma_f64_4x4x4f64::getOperationName();
536  }
537 
538  if (sourceElem.isFloat8E5M2FNUZ() && destElem.isF32() && chipset >= kGfx940) {
539  // Known to be correct because there are no scalar f8 instructions and
540  // because a length mismatch will have been caught by the verifier.
541  Type sourceBElem =
542  cast<VectorType>(mfma.getSourceB().getType()).getElementType();
543  if (m == 16 && n == 16 && k == 32 && b == 1) {
544  if (sourceBElem.isFloat8E5M2FNUZ())
545  return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
546  if (sourceBElem.isFloat8E4M3FNUZ())
547  return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
548  }
549  if (m == 32 && n == 32 && k == 16 && b == 1) {
550  if (sourceBElem.isFloat8E5M2FNUZ())
551  return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
552  if (sourceBElem.isFloat8E4M3FNUZ())
553  return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
554  }
555  }
556 
557  if (sourceElem.isFloat8E4M3FNUZ() && destElem.isF32() && chipset >= kGfx940) {
558  Type sourceBElem =
559  cast<VectorType>(mfma.getSourceB().getType()).getElementType();
560  if (m == 16 && n == 16 && k == 32 && b == 1) {
561  if (sourceBElem.isFloat8E5M2FNUZ())
562  return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
563  if (sourceBElem.isFloat8E4M3FNUZ())
564  return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
565  }
566  if (m == 32 && n == 32 && k == 16 && b == 1) {
567  if (sourceBElem.isFloat8E5M2FNUZ())
568  return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
569  if (sourceBElem.isFloat8E4M3FNUZ())
570  return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
571  }
572  }
573 
574  return std::nullopt;
575 }
576 
577 /// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
578 /// if one exists. This includes checking to ensure the intrinsic is supported
579 /// on the architecture you are compiling for.
580 static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
581  Chipset chipset) {
582  auto sourceVectorType = dyn_cast<VectorType>(wmma.getSourceA().getType());
583  auto destVectorType = dyn_cast<VectorType>(wmma.getDestC().getType());
584  auto elemSourceType = sourceVectorType.getElementType();
585  auto elemDestType = destVectorType.getElementType();
586 
587  if (elemSourceType.isF16() && elemDestType.isF32())
588  return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
589  if (elemSourceType.isBF16() && elemDestType.isF32())
590  return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
591  if (elemSourceType.isF16() && elemDestType.isF16())
592  return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
593  if (elemSourceType.isBF16() && elemDestType.isBF16())
594  return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
595  if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
596  return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
597  if (elemSourceType.isFloat8E4M3FN() && elemDestType.isF32())
598  return ROCDL::wmma_f32_16x16x16_fp8::getOperationName();
599  if (elemSourceType.isFloat8E5M2() && elemDestType.isF32())
600  return ROCDL::wmma_f32_16x16x16_bf8::getOperationName();
601  return std::nullopt;
602 }
603 
604 namespace {
605 struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
606  MFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
607  : ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {}
608 
609  Chipset chipset;
610 
611  LogicalResult
612  matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
613  ConversionPatternRewriter &rewriter) const override {
614  Location loc = op.getLoc();
615  Type outType = typeConverter->convertType(op.getDestD().getType());
616  Type intrinsicOutType = outType;
617  if (auto outVecType = dyn_cast<VectorType>(outType))
618  if (outVecType.getElementType().isBF16())
619  intrinsicOutType = outVecType.clone(rewriter.getI16Type());
620 
621  if (chipset.majorVersion != 9 || chipset < kGfx908)
622  return op->emitOpError("MFMA only supported on gfx908+");
623  uint32_t getBlgpField = static_cast<uint32_t>(op.getBlgp());
624  if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
625  if (chipset < kGfx940)
626  return op.emitOpError("negation unsupported on older than gfx940");
627  getBlgpField |=
628  op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
629  }
630  std::optional<StringRef> maybeIntrinsic = mfmaOpToIntrinsic(op, chipset);
631  if (!maybeIntrinsic.has_value())
632  return op.emitOpError("no intrinsic matching MFMA size on given chipset");
633  OperationState loweredOp(loc, *maybeIntrinsic);
634  loweredOp.addTypes(intrinsicOutType);
635  loweredOp.addOperands(
636  {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
637  convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
638  adaptor.getDestC(), createI32Constant(rewriter, loc, op.getCbsz()),
639  createI32Constant(rewriter, loc, op.getAbid()),
640  createI32Constant(rewriter, loc, getBlgpField)});
641  Value lowered = rewriter.create(loweredOp)->getResult(0);
642  if (outType != intrinsicOutType)
643  lowered = rewriter.create<LLVM::BitcastOp>(loc, outType, lowered);
644  rewriter.replaceOp(op, lowered);
645  return success();
646  }
647 };
648 
649 struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
650  WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
651  : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}
652 
653  Chipset chipset;
654 
655  LogicalResult
656  matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
657  ConversionPatternRewriter &rewriter) const override {
658  Location loc = op.getLoc();
659  auto outType =
660  typeConverter->convertType<VectorType>(op.getDestD().getType());
661  if (!outType)
662  return rewriter.notifyMatchFailure(op, "type conversion failed");
663 
664  if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
665  return op->emitOpError("WMMA only supported on gfx11 and gfx12");
666 
667  // The WMMA operations represent vectors of bf16s as vectors of i16s, so we
668  // need to bitcast bfloats to i16 and then bitcast them back.
669  VectorType rawOutType = outType;
670  if (outType.getElementType().isBF16())
671  rawOutType = outType.clone(rewriter.getI16Type());
672 
673  std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);
674 
675  if (!maybeIntrinsic.has_value())
676  return op.emitOpError("no intrinsic matching WMMA on the given chipset");
677 
678  OperationState loweredOp(loc, *maybeIntrinsic);
679  loweredOp.addTypes(rawOutType);
680 
681  SmallVector<Value, 4> operands;
682  wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(),
683  adaptor.getSourceA(), op.getSourceA(), operands);
684  wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(),
685  adaptor.getSourceB(), op.getSourceB(), operands);
686  wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(),
687  op.getSubwordOffset(), op.getClamp(), operands);
688 
689  loweredOp.addOperands(operands);
690  Operation *lowered = rewriter.create(loweredOp);
691 
692  Operation *maybeCastBack = lowered;
693  if (rawOutType != outType)
694  maybeCastBack =
695  rewriter.create<LLVM::BitcastOp>(loc, outType, lowered->getResult(0));
696  rewriter.replaceOp(op, maybeCastBack->getResults());
697 
698  return success();
699  }
700 };
701 
702 namespace {
703 struct ExtPackedFp8OpLowering final
704  : public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
705  ExtPackedFp8OpLowering(const LLVMTypeConverter &converter, Chipset chipset)
706  : ConvertOpToLLVMPattern<amdgpu::ExtPackedFp8Op>(converter),
707  chipset(chipset) {}
708  Chipset chipset;
709 
710  LogicalResult
711  matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
712  ConversionPatternRewriter &rewriter) const override;
713 };
714 
715 struct PackedTrunc2xFp8OpLowering final
716  : public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> {
717  PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter,
718  Chipset chipset)
719  : ConvertOpToLLVMPattern<amdgpu::PackedTrunc2xFp8Op>(converter),
720  chipset(chipset) {}
721  Chipset chipset;
722 
723  LogicalResult
724  matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
725  ConversionPatternRewriter &rewriter) const override;
726 };
727 
728 struct PackedStochRoundFp8OpLowering final
729  : public ConvertOpToLLVMPattern<PackedStochRoundFp8Op> {
730  PackedStochRoundFp8OpLowering(const LLVMTypeConverter &converter,
731  Chipset chipset)
732  : ConvertOpToLLVMPattern<amdgpu::PackedStochRoundFp8Op>(converter),
733  chipset(chipset) {}
734  Chipset chipset;
735 
736  LogicalResult
737  matchAndRewrite(PackedStochRoundFp8Op op,
738  PackedStochRoundFp8OpAdaptor adaptor,
739  ConversionPatternRewriter &rewriter) const override;
740 };
741 } // end namespace
742 
743 LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
744  ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
745  ConversionPatternRewriter &rewriter) const {
746  Location loc = op.getLoc();
747  if (chipset.majorVersion != 9 || chipset < kGfx940)
748  return rewriter.notifyMatchFailure(
749  loc, "Fp8 conversion instructions are not available on target "
750  "architecture and their emulation is not implemented");
751  Type v4i8 =
752  getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type()));
753  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
754  Type f32 = getTypeConverter()->convertType(op.getResult().getType());
755 
756  Value source = adaptor.getSource();
757  auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
758  Type sourceElemType = getElementTypeOrSelf(op.getSource());
759  // Extend to a v4i8
760  if (!sourceVecType || sourceVecType.getNumElements() < 4) {
761  Value longVec = rewriter.create<LLVM::UndefOp>(loc, v4i8);
762  if (!sourceVecType) {
763  longVec = rewriter.create<LLVM::InsertElementOp>(
764  loc, longVec, source, createI32Constant(rewriter, loc, 0));
765  } else {
766  for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
767  Value idx = createI32Constant(rewriter, loc, i);
768  Value elem = rewriter.create<LLVM::ExtractElementOp>(loc, source, idx);
769  longVec =
770  rewriter.create<LLVM::InsertElementOp>(loc, longVec, elem, idx);
771  }
772  }
773  source = longVec;
774  }
775  Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
776  Value wordSel = createI32Constant(rewriter, loc, op.getIndex());
777  if (sourceElemType.isFloat8E5M2FNUZ()) {
778  rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
779  wordSel);
780  } else if (sourceElemType.isFloat8E4M3FNUZ()) {
781  rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
782  wordSel);
783  }
784  return success();
785 }
786 
787 LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
788  PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
789  ConversionPatternRewriter &rewriter) const {
790  Location loc = op.getLoc();
791  if (chipset.majorVersion != 9 || chipset < kGfx940)
792  return rewriter.notifyMatchFailure(
793  loc, "Fp8 conversion instructions are not available on target "
794  "architecture and their emulation is not implemented");
795  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
796 
797  Type resultType = op.getResult().getType();
798  Type resultElemType = getElementTypeOrSelf(resultType);
799 
800  Value sourceA = adaptor.getSourceA();
801  Value sourceB = adaptor.getSourceB();
802  if (!sourceB)
803  sourceB = rewriter.create<LLVM::UndefOp>(loc, sourceA.getType());
804  Value existing = adaptor.getExisting();
805  if (existing)
806  existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
807  else
808  existing = rewriter.create<LLVM::UndefOp>(loc, i32);
809  Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex());
810 
811  Value result;
812  if (resultElemType.isFloat8E5M2FNUZ())
813  result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
814  existing, wordSel);
815  else if (resultElemType.isFloat8E4M3FNUZ())
816  result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
817  existing, wordSel);
818 
819  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
820  op, getTypeConverter()->convertType(resultType), result);
821  return success();
822 }
823 
824 LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
825  PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
826  ConversionPatternRewriter &rewriter) const {
827  Location loc = op.getLoc();
828  if (chipset.majorVersion != 9 || chipset < kGfx940)
829  return rewriter.notifyMatchFailure(
830  loc, "Fp8 conversion instructions are not available on target "
831  "architecture and their emulation is not implemented");
832  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
833 
834  Type resultType = op.getResult().getType();
835  Type resultElemType = getElementTypeOrSelf(resultType);
836 
837  Value source = adaptor.getSource();
838  Value stoch = adaptor.getStochiasticParam();
839  Value existing = adaptor.getExisting();
840  if (existing)
841  existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
842  else
843  existing = rewriter.create<LLVM::UndefOp>(loc, i32);
844  Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex());
845 
846  Value result;
847  if (resultElemType.isFloat8E5M2FNUZ())
848  result = rewriter.create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch,
849  existing, byteSel);
850  else if (resultElemType.isFloat8E4M3FNUZ())
851  result = rewriter.create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch,
852  existing, byteSel);
853 
854  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
855  op, getTypeConverter()->convertType(resultType), result);
856  return success();
857 }
858 
// Implement the AMDGPU_DPPLowering class that will convert the amdgpu.dpp
// operation into the corresponding ROCDL instructions.
struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
  AMDGPUDPPLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
  // Target chipset the pattern was configured with (not consulted in this
  // rewrite).
  Chipset chipset;

  LogicalResult
  matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    // Convert the source operand to the corresponding LLVM type
    Location loc = DppOp.getLoc();
    Value src = adaptor.getSrc();
    Value old = adaptor.getOld();
    Type srcType = src.getType();
    Type oldType = old.getType();
    // Pick the LLVM type the ROCDL dpp op operates on: sub-32-bit values are
    // widened to i32; 32/64-bit values keep their int/float flavor.
    Type llvmType = nullptr;
    if (srcType.getIntOrFloatBitWidth() < 32) {
      llvmType = rewriter.getI32Type();
    } else if (isa<FloatType>(srcType)) {
      llvmType = (srcType.getIntOrFloatBitWidth() == 32)
                     ? rewriter.getF32Type()
                     : rewriter.getF64Type();
    } else if (isa<IntegerType>(srcType)) {
      llvmType = (srcType.getIntOrFloatBitWidth() == 32)
                     ? rewriter.getI32Type()
                     : rewriter.getI64Type();
    }
    // Integer type with the same bit width as the source element; used as a
    // bitcast bridge for small float types.
    auto llvmSrcIntType = typeConverter->convertType(
        rewriter.getIntegerType(srcType.getIntOrFloatBitWidth()));

    // Widen an operand narrower than 32 bits to i32: bitcast small floats to
    // the same-width integer, insert into lane 0 of a vector that is 32 bits
    // wide in total, then bitcast that vector to `llvmType`.
    auto convertOperand = [&](Value operand, Type operandType) {
      if (operandType.getIntOrFloatBitWidth() <= 16) {
        if (llvm::isa<FloatType>(operandType)) {
          operand =
              rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
        }
        auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
            32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
        Value undefVec = rewriter.create<LLVM::UndefOp>(loc, llvmVecType);
        operand = rewriter.create<LLVM::InsertElementOp>(
            loc, undefVec, operand, createI32Constant(rewriter, loc, 0));
        operand = rewriter.create<LLVM::BitcastOp>(loc, llvmType, operand);
      }
      return operand;
    };

    src = convertOperand(src, srcType);
    old = convertOperand(old, oldType);

    // This is taken from the following file llvm/lib/Target/AMDGPU/SIDefines.h
    enum DppCtrl : unsigned {
      ROW_SHL0 = 0x100,
      ROW_SHR0 = 0x110,
      ROW_ROR0 = 0x120,
      WAVE_SHL1 = 0x130,
      WAVE_ROL1 = 0x134,
      WAVE_SHR1 = 0x138,
      WAVE_ROR1 = 0x13C,
      ROW_MIRROR = 0x140,
      ROW_HALF_MIRROR = 0x141,
      BCAST15 = 0x142,
      BCAST31 = 0x143,
    };

    auto kind = DppOp.getKind();
    auto permArgument = DppOp.getPermArgument();
    uint32_t DppCtrl = 0;

    // Translate the amdgpu.dpp permutation kind (plus its optional argument)
    // into the hardware dpp_ctrl immediate.
    switch (kind) {

    case DPPPerm::quad_perm:
      // Pack the four 2-bit lane selectors into the low byte of dpp_ctrl.
      if (auto quadPermAttr = cast<ArrayAttr>(*permArgument)) {
        int32_t i = 0;
        for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
          uint32_t num = elem.getInt();
          DppCtrl |= num << (i * 2);
          i++;
        }
      }
      break;
    case DPPPerm::row_shl:
      // Shift amount is added to the ROW_SHL base encoding.
      if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
        DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
      }
      break;
    case DPPPerm::row_shr:
      if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
        DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
      }
      break;
    case DPPPerm::row_ror:
      if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
        DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
      }
      break;
    case DPPPerm::wave_shl:
      DppCtrl = DppCtrl::WAVE_SHL1;
      break;
    case DPPPerm::wave_shr:
      DppCtrl = DppCtrl::WAVE_SHR1;
      break;
    case DPPPerm::wave_rol:
      DppCtrl = DppCtrl::WAVE_ROL1;
      break;
    case DPPPerm::wave_ror:
      DppCtrl = DppCtrl::WAVE_ROR1;
      break;
    case DPPPerm::row_mirror:
      DppCtrl = DppCtrl::ROW_MIRROR;
      break;
    case DPPPerm::row_half_mirror:
      DppCtrl = DppCtrl::ROW_HALF_MIRROR;
      break;
    case DPPPerm::row_bcast_15:
      DppCtrl = DppCtrl::BCAST15;
      break;
    case DPPPerm::row_bcast_31:
      DppCtrl = DppCtrl::BCAST31;
      break;
    }

    // Check for row_mask, bank_mask, bound_ctrl if they exist and create
    // constants
    auto rowMask = DppOp->getAttrOfType<IntegerAttr>("row_mask").getInt();
    auto bankMask = DppOp->getAttrOfType<IntegerAttr>("bank_mask").getInt();
    bool boundCtrl = DppOp->getAttrOfType<BoolAttr>("bound_ctrl").getValue();

    // create a ROCDL_DPPMovOp instruction with the appropriate attributes
    auto dppMovOp = rewriter.create<ROCDL::DPPUpdateOp>(
        loc, llvmType, old, src, DppCtrl, rowMask, bankMask, boundCtrl);

    Value result = dppMovOp.getRes();
    // Narrow the i32 result back to the original sub-32-bit type, bitcasting
    // through the same-width integer for non-integer sources.
    if (srcType.getIntOrFloatBitWidth() < 32) {
      result = rewriter.create<LLVM::TruncOp>(loc, llvmSrcIntType, result);
      if (!llvm::isa<IntegerType>(srcType)) {
        result = rewriter.create<LLVM::BitcastOp>(loc, srcType, result);
      }
    }

    // We are replacing the AMDGPU_DPPOp instruction with the new
    // ROCDL_DPPMovOp instruction
    rewriter.replaceOp(DppOp, ValueRange(result));
    return success();
  }
};
1007 
1008 struct ConvertAMDGPUToROCDLPass
1009  : public impl::ConvertAMDGPUToROCDLBase<ConvertAMDGPUToROCDLPass> {
1010  ConvertAMDGPUToROCDLPass() = default;
1011 
1012  void runOnOperation() override {
1013  MLIRContext *ctx = &getContext();
1014  FailureOr<Chipset> maybeChipset = Chipset::parse(chipset);
1015  if (failed(maybeChipset)) {
1016  emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
1017  return signalPassFailure();
1018  }
1019 
1020  RewritePatternSet patterns(ctx);
1021  LLVMTypeConverter converter(ctx);
1022  populateAMDGPUToROCDLConversionPatterns(converter, patterns, *maybeChipset);
1023  LLVMConversionTarget target(getContext());
1024  target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
1025  target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
1026  target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
1027  if (failed(applyPartialConversion(getOperation(), target,
1028  std::move(patterns))))
1029  signalPassFailure();
1030  }
1031 };
1032 } // namespace
1033 
1035  const LLVMTypeConverter &converter, RewritePatternSet &patterns,
1036  Chipset chipset) {
1037  patterns
1038  .add<RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
1039  RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
1040  RawBufferOpLowering<RawBufferAtomicFaddOp,
1041  ROCDL::RawPtrBufferAtomicFaddOp>,
1042  RawBufferOpLowering<RawBufferAtomicFmaxOp,
1043  ROCDL::RawPtrBufferAtomicFmaxOp>,
1044  RawBufferOpLowering<RawBufferAtomicSmaxOp,
1045  ROCDL::RawPtrBufferAtomicSmaxOp>,
1046  RawBufferOpLowering<RawBufferAtomicUminOp,
1047  ROCDL::RawPtrBufferAtomicUminOp>,
1048  RawBufferOpLowering<RawBufferAtomicCmpswapOp,
1049  ROCDL::RawPtrBufferAtomicCmpSwap>,
1050  AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
1051  MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
1052  PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering>(converter,
1053  chipset);
1054 }
1055 
/// Creates an instance of the AMDGPU-to-ROCDL conversion pass.
std::unique_ptr<Pass> mlir::createConvertAMDGPUToROCDLPass() {
  return std::make_unique<ConvertAMDGPUToROCDLPass>();
}
static std::optional< StringRef > wmmaOpToIntrinsic(WMMAOp wmma, Chipset chipset)
Return the rocdl intrinsic corresponding to a WMMA operation wmma if one exists.
static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc, bool value)
static std::optional< StringRef > mfmaOpToIntrinsic(MFMAOp mfma, Chipset chipset)
Return the rocdl intrinsic corresponding to a MFMA operation mfma if one exists.
static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, Value output, int32_t subwordOffset, bool clamp, SmallVector< Value, 4 > &operands)
Push the output operand.
static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput, Value mlirInput, SmallVector< Value, 4 > &operands)
Push an input operand.
static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter, Location loc, Value input)
Converts a MFMA vector operand from MLIR AMDGPU dialect convention to ROCDL and LLVM AMDGPU intrinsic...
static Value createI32Constant(ConversionPatternRewriter &rewriter, Location loc, int32_t value)
static MLIRContext * getContext(OpFoldResult val)
static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound)
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
IntegerType getI16Type()
Definition: Builders.cpp:105
FloatType getF32Type()
Definition: Builders.cpp:87
IntegerAttr getI16IntegerAttr(int16_t value)
Definition: Builders.cpp:257
IntegerType getI64Type()
Definition: Builders.cpp:109
IntegerType getI32Type()
Definition: Builders.cpp:107
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:111
MLIRContext * getContext() const
Definition: Builders.h:55
IntegerType getI1Type()
Definition: Builders.cpp:97
IntegerType getI8Type()
Definition: Builders.cpp:103
FloatType getF64Type()
Definition: Builders.cpp:89
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:143
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Definition: MemRefBuilder.h:33
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:528
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
result_range getResults()
Definition: Operation.h:410
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
Type conversion class.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isF64() const
Definition: Types.cpp:60
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition: Types.cpp:87
bool isF32() const
Definition: Types.cpp:59
bool isFloat8E4M3FNUZ() const
Definition: Types.cpp:46
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition: Types.cpp:99
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:66
bool isF16() const
Definition: Types.cpp:57
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:133
bool isBF16() const
Definition: Types.cpp:56
bool isFloat8E5M2FNUZ() const
Definition: Types.cpp:43
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Include the generated interface declarations.
std::unique_ptr< Pass > createConvertAMDGPUToROCDLPass()
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
void populateAMDGPUToROCDLConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, amdgpu::Chipset chipset)
Note: The ROCDL target does not support the LLVM bfloat type at this time and so this function will a...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Represents the amdgpu gfx chipset version, e.g., gfx90a, gfx942, gfx1103.
Definition: Chipset.h:22
unsigned majorVersion
Definition: Chipset.h:23
static FailureOr< Chipset > parse(StringRef name)
Parses the chipset version string and returns the chipset on success, and failure otherwise.
Definition: Chipset.cpp:14