MLIR  20.0.0git
AMDGPUToROCDL.cpp
Go to the documentation of this file.
1 //===- AMDGPUToROCDL.cpp - AMDGPU to ROCDL dialect conversion -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/TypeUtilities.h"
20 #include "mlir/Pass/Pass.h"
21 
22 #include "llvm/ADT/STLExtras.h"
23 #include <optional>
24 
25 namespace mlir {
26 #define GEN_PASS_DEF_CONVERTAMDGPUTOROCDL
27 #include "mlir/Conversion/Passes.h.inc"
28 } // namespace mlir
29 
30 using namespace mlir;
31 using namespace mlir::amdgpu;
32 
34  Location loc, int32_t value) {
35  Type llvmI32 = rewriter.getI32Type();
36  return rewriter.create<LLVM::ConstantOp>(loc, llvmI32, value);
37 }
38 
40  bool value) {
41  Type llvmI1 = rewriter.getI1Type();
42  return rewriter.create<LLVM::ConstantOp>(loc, llvmI1, value);
43 }
44 
45 namespace {
46 // Define commonly used chipsets versions for convenience.
47 constexpr Chipset kGfx908 = Chipset(9, 0, 8);
48 constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
49 constexpr Chipset kGfx940 = Chipset(9, 4, 0);
50 
51 /// Define lowering patterns for raw buffer ops
52 template <typename GpuOp, typename Intrinsic>
53 struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
54  RawBufferOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
55  : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}
56 
57  Chipset chipset;
58  static constexpr uint32_t maxVectorOpWidth = 128;
59 
60  LogicalResult
61  matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor,
62  ConversionPatternRewriter &rewriter) const override {
63  Location loc = gpuOp.getLoc();
64  Value memref = adaptor.getMemref();
65  Value unconvertedMemref = gpuOp.getMemref();
66  MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());
67 
68  if (chipset.majorVersion < 9)
69  return gpuOp.emitOpError("raw buffer ops require GCN or higher");
70 
71  Value storeData = adaptor.getODSOperands(0)[0];
72  if (storeData == memref) // no write component to this op
73  storeData = Value();
74  Type wantedDataType;
75  if (storeData)
76  wantedDataType = storeData.getType();
77  else
78  wantedDataType = gpuOp.getODSResults(0)[0].getType();
79 
80  Value atomicCmpData = Value();
81  // Operand index 1 of a load is the indices, trying to read them can crash.
82  if (storeData) {
83  Value maybeCmpData = adaptor.getODSOperands(1)[0];
84  if (maybeCmpData != memref)
85  atomicCmpData = maybeCmpData;
86  }
87 
88  Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
89 
90  Type i32 = rewriter.getI32Type();
91  Type llvmI32 = this->typeConverter->convertType(i32);
92  Type llvmI16 = this->typeConverter->convertType(rewriter.getI16Type());
93 
94  auto toI32 = [&](Value val) -> Value {
95  if (val.getType() == llvmI32)
96  return val;
97 
98  return rewriter.create<LLVM::TruncOp>(loc, llvmI32, val);
99  };
100 
101  int64_t elementByteWidth = memrefType.getElementTypeBitWidth() / 8;
102  Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);
103 
104  // If we want to load a vector<NxT> with total size <= 32
105  // bits, use a scalar load and bitcast it. Similarly, if bitsize(T) < 32
106  // and the total load size is >= 32, use a vector load of N / (bitsize(T) /
107  // 32) x i32 and bitcast. Also, the CAS intrinsic requires integer operands,
108  // so bitcast any floats to integers.
109  Type llvmBufferValType = llvmWantedDataType;
110  if (atomicCmpData) {
111  if (auto floatType = dyn_cast<FloatType>(wantedDataType))
112  llvmBufferValType = this->getTypeConverter()->convertType(
113  rewriter.getIntegerType(floatType.getWidth()));
114  }
115  if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
116  uint32_t vecLen = dataVector.getNumElements();
117  uint32_t elemBits = dataVector.getElementTypeBitWidth();
118  uint32_t totalBits = elemBits * vecLen;
119  bool usePackedFp16 =
120  isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
121  if (totalBits > maxVectorOpWidth)
122  return gpuOp.emitOpError(
123  "Total width of loads or stores must be no more than " +
124  Twine(maxVectorOpWidth) + " bits, but we call for " +
125  Twine(totalBits) +
126  " bits. This should've been caught in validation");
127  if (!usePackedFp16 && elemBits < 32) {
128  if (totalBits > 32) {
129  if (totalBits % 32 != 0)
130  return gpuOp.emitOpError("Load or store of more than 32-bits that "
131  "doesn't fit into words. Can't happen\n");
132  llvmBufferValType = this->typeConverter->convertType(
133  VectorType::get(totalBits / 32, i32));
134  } else {
135  llvmBufferValType = this->typeConverter->convertType(
136  rewriter.getIntegerType(totalBits));
137  }
138  }
139  }
140 
142  if (storeData) {
143  if (llvmBufferValType != llvmWantedDataType) {
144  Value castForStore =
145  rewriter.create<LLVM::BitcastOp>(loc, llvmBufferValType, storeData);
146  args.push_back(castForStore);
147  } else {
148  args.push_back(storeData);
149  }
150  }
151 
152  if (atomicCmpData) {
153  if (llvmBufferValType != llvmWantedDataType) {
154  Value castForCmp = rewriter.create<LLVM::BitcastOp>(
155  loc, llvmBufferValType, atomicCmpData);
156  args.push_back(castForCmp);
157  } else {
158  args.push_back(atomicCmpData);
159  }
160  }
161 
162  // Construct buffer descriptor from memref, attributes
163  int64_t offset = 0;
164  SmallVector<int64_t, 5> strides;
165  if (failed(getStridesAndOffset(memrefType, strides, offset)))
166  return gpuOp.emitOpError("Can't lower non-stride-offset memrefs");
167 
168  MemRefDescriptor memrefDescriptor(memref);
169 
170  Value ptr = memrefDescriptor.alignedPtr(rewriter, loc);
171  // The stride value is always 0 for raw buffers. This also disables
172  // swizling.
173  Value stride = rewriter.create<LLVM::ConstantOp>(
174  loc, llvmI16, rewriter.getI16IntegerAttr(0));
175  Value numRecords;
176  if (memrefType.hasStaticShape() && memrefType.getLayout().isIdentity()) {
177  numRecords = createI32Constant(
178  rewriter, loc,
179  static_cast<int32_t>(memrefType.getNumElements() * elementByteWidth));
180  } else {
181  Value maxIndex;
182  for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
183  Value size = toI32(memrefDescriptor.size(rewriter, loc, i));
184  Value stride = toI32(memrefDescriptor.stride(rewriter, loc, i));
185  stride = rewriter.create<LLVM::MulOp>(loc, stride, byteWidthConst);
186  Value maxThisDim = rewriter.create<LLVM::MulOp>(loc, size, stride);
187  maxIndex = maxIndex ? rewriter.create<LLVM::MaximumOp>(loc, maxIndex,
188  maxThisDim)
189  : maxThisDim;
190  }
191  numRecords = maxIndex;
192  }
193 
194  // Flag word:
195  // bits 0-11: dst sel, ignored by these intrinsics
196  // bits 12-14: data format (ignored, must be nonzero, 7=float)
197  // bits 15-18: data format (ignored, must be nonzero, 4=32bit)
198  // bit 19: In nested heap (0 here)
199  // bit 20: Behavior on unmap (0 means "return 0 / ignore")
200  // bits 21-22: Index stride for swizzles (N/A)
201  // bit 23: Add thread ID (0)
202  // bit 24: Reserved to 1 (RDNA) or 0 (CDNA)
203  // bits 25-26: Reserved (0)
204  // bit 27: Buffer is non-volatile (CDNA only)
205  // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 =
206  // none, 3 = either swizzles or testing against offset field) RDNA only
207  // bits 30-31: Type (must be 0)
208  uint32_t flags = (7 << 12) | (4 << 15);
209  if (chipset.majorVersion >= 10) {
210  flags |= (1 << 24);
211  uint32_t oob = adaptor.getBoundsCheck() ? 3 : 2;
212  flags |= (oob << 28);
213  }
214  Value flagsConst = createI32Constant(rewriter, loc, flags);
215  Type rsrcType = LLVM::LLVMPointerType::get(rewriter.getContext(), 8);
216  Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
217  loc, rsrcType, ptr, stride, numRecords, flagsConst);
218  args.push_back(resource);
219 
220  // Indexing (voffset)
221  Value voffset = createI32Constant(rewriter, loc, 0);
222  for (auto pair : llvm::enumerate(adaptor.getIndices())) {
223  size_t i = pair.index();
224  Value index = pair.value();
225  Value strideOp;
226  if (ShapedType::isDynamic(strides[i])) {
227  strideOp = rewriter.create<LLVM::MulOp>(
228  loc, toI32(memrefDescriptor.stride(rewriter, loc, i)),
229  byteWidthConst);
230  } else {
231  strideOp =
232  createI32Constant(rewriter, loc, strides[i] * elementByteWidth);
233  }
234  index = rewriter.create<LLVM::MulOp>(loc, index, strideOp);
235  voffset = rewriter.create<LLVM::AddOp>(loc, voffset, index);
236  }
237  if (adaptor.getIndexOffset()) {
238  int32_t indexOffset = *gpuOp.getIndexOffset() * elementByteWidth;
239  Value extraOffsetConst = createI32Constant(rewriter, loc, indexOffset);
240  voffset =
241  voffset ? rewriter.create<LLVM::AddOp>(loc, voffset, extraOffsetConst)
242  : extraOffsetConst;
243  }
244  args.push_back(voffset);
245 
246  Value sgprOffset = adaptor.getSgprOffset();
247  if (!sgprOffset)
248  sgprOffset = createI32Constant(rewriter, loc, 0);
249  if (ShapedType::isDynamic(offset))
250  sgprOffset = rewriter.create<LLVM::AddOp>(
251  loc, toI32(memrefDescriptor.offset(rewriter, loc)), sgprOffset);
252  else if (offset > 0)
253  sgprOffset = rewriter.create<LLVM::AddOp>(
254  loc, sgprOffset, createI32Constant(rewriter, loc, offset));
255  args.push_back(sgprOffset);
256 
257  // bit 0: GLC = 0 (atomics drop value, less coherency)
258  // bits 1-2: SLC, DLC = 0 (similarly)
259  // bit 3: swizzled (0 for raw)
260  args.push_back(createI32Constant(rewriter, loc, 0));
261 
262  llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
263  llvmBufferValType);
264  Operation *lowered = rewriter.create<Intrinsic>(loc, resultTypes, args,
266  if (lowered->getNumResults() == 1) {
267  Value replacement = lowered->getResult(0);
268  if (llvmBufferValType != llvmWantedDataType) {
269  replacement = rewriter.create<LLVM::BitcastOp>(loc, llvmWantedDataType,
270  replacement);
271  }
272  rewriter.replaceOp(gpuOp, replacement);
273  } else {
274  rewriter.eraseOp(gpuOp);
275  }
276  return success();
277  }
278 };
279 
280 struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
281  LDSBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
282  : ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {}
283 
284  Chipset chipset;
285 
286  LogicalResult
287  matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
288  ConversionPatternRewriter &rewriter) const override {
289  bool requiresInlineAsm = chipset < kGfx90a || chipset.majorVersion == 11;
290 
291  if (requiresInlineAsm) {
292  auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
293  LLVM::AsmDialect::AD_ATT);
294  const char *asmStr =
295  ";;;WARNING: BREAKS DEBUG WATCHES\ns_waitcnt lgkmcnt(0)\ns_barrier";
296  const char *constraints = "";
297  rewriter.replaceOpWithNewOp<LLVM::InlineAsmOp>(
298  op,
299  /*resultTypes=*/TypeRange(), /*operands=*/ValueRange(),
300  /*asm_string=*/asmStr, constraints, /*has_side_effects=*/true,
301  /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
302  /*operand_attrs=*/ArrayAttr());
303  return success();
304  }
305  if (chipset.majorVersion < 12) {
306  constexpr int32_t ldsOnlyBitsGfx6789 = ~(0x1f << 8);
307  constexpr int32_t ldsOnlyBitsGfx10 = ~(0x3f << 8);
308  // Left in place in case someone disables the inline ASM path or future
309  // chipsets use the same bit pattern.
310  constexpr int32_t ldsOnlyBitsGfx11 = ~(0x3f << 4);
311 
312  int32_t ldsOnlyBits;
313  if (chipset.majorVersion == 11)
314  ldsOnlyBits = ldsOnlyBitsGfx11;
315  else if (chipset.majorVersion == 10)
316  ldsOnlyBits = ldsOnlyBitsGfx10;
317  else if (chipset.majorVersion <= 9)
318  ldsOnlyBits = ldsOnlyBitsGfx6789;
319  else
320  return op.emitOpError(
321  "don't know how to lower this for chipset major version")
322  << chipset.majorVersion;
323 
324  Location loc = op->getLoc();
325  rewriter.create<ROCDL::WaitcntOp>(loc, ldsOnlyBits);
326  rewriter.replaceOpWithNewOp<ROCDL::SBarrierOp>(op);
327  } else {
328  Location loc = op->getLoc();
329  rewriter.create<ROCDL::WaitDscntOp>(loc, 0);
330  rewriter.create<ROCDL::BarrierSignalOp>(loc, -1);
331  rewriter.replaceOpWithNewOp<ROCDL::BarrierWaitOp>(op, -1);
332  }
333 
334  return success();
335  }
336 };
337 
338 struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
339  SchedBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
340  : ConvertOpToLLVMPattern<SchedBarrierOp>(converter), chipset(chipset) {}
341 
342  Chipset chipset;
343 
344  LogicalResult
345  matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor,
346  ConversionPatternRewriter &rewriter) const override {
347  rewriter.replaceOpWithNewOp<ROCDL::SchedBarrier>(op,
348  (uint32_t)op.getOpts());
349  return success();
350  }
351 };
352 
353 } // namespace
354 
355 /// Converts a MFMA vector operand from MLIR AMDGPU dialect convention to ROCDL
356 /// and LLVM AMDGPU intrinsics convention.
357 ///
358 /// Specifically:
359 /// 1. If `input` is a vector of N bytes, bitcast it to a (N * 8)-bit integer.
360 /// 2. If the element type is bfloat16, bitcast it to i16.
362  Location loc, Value input) {
363  Type inputType = input.getType();
364  if (auto vectorType = dyn_cast<VectorType>(inputType)) {
365  if (vectorType.getElementType().isBF16())
366  return rewriter.create<LLVM::BitcastOp>(
367  loc, vectorType.clone(rewriter.getI16Type()), input);
368  if (vectorType.getElementType().isInteger(8)) {
369  return rewriter.create<LLVM::BitcastOp>(
370  loc, rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
371  }
372  }
373  return input;
374 }
375 
376 /// Push an input operand. If it is a float type, nothing to do. If it is
377 /// an integer type, then we need to also push its signdness (1 for signed, 0
378 /// for unsigned) and we need to pack the input 16xi8 vector into a 4xi32
379 /// vector. We also need to convert bfloat inputs to i16 to account for the lack
380 /// of bfloat support in the WMMA intrinsics themselves.
382  Location loc,
383  const TypeConverter *typeConverter,
384  bool isUnsigned, Value llvmInput,
385  Value mlirInput,
386  SmallVector<Value, 4> &operands) {
387  Type inputType = llvmInput.getType();
388  auto vectorType = dyn_cast<VectorType>(inputType);
389  Type elemType = vectorType.getElementType();
390 
391  if (elemType.isBF16())
392  llvmInput = rewriter.create<LLVM::BitcastOp>(
393  loc, vectorType.clone(rewriter.getI16Type()), llvmInput);
394  if (!elemType.isInteger(8)) {
395  operands.push_back(llvmInput);
396  return;
397  }
398 
399  // We need to check the type of the input before conversion to properly test
400  // for int8. This is because, in LLVM, fp8 type is converted to int8, so the
401  // fp8/int8 information is lost during the conversion process.
402  auto mlirInputType = cast<VectorType>(mlirInput.getType());
403  bool isInputInt8 = mlirInputType.getElementType().isInteger(8);
404  if (isInputInt8) {
405  // if element type is 8-bit signed or unsigned, ignore the isUnsigned flag
406  bool localIsUnsigned = isUnsigned;
407  if (elemType.isUnsignedInteger(8)) {
408  localIsUnsigned = true;
409  } else if (elemType.isSignedInteger(8)) {
410  localIsUnsigned = false;
411  }
412  Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
413  operands.push_back(sign);
414  }
415 
416  int64_t numBytes = vectorType.getNumElements();
417  Type i32 = rewriter.getI32Type();
418  VectorType vectorType32bits = VectorType::get(numBytes * 8 / 32, i32);
419  auto llvmVectorType32bits = typeConverter->convertType(vectorType32bits);
420  Value result = rewriter.createOrFold<LLVM::BitcastOp>(
421  loc, llvmVectorType32bits, llvmInput);
422  operands.push_back(result);
423 }
424 
425 /// Push the output operand. For many cases this is only pushing the output in
426 /// the operand list. But when we have f16 -> f16 or bf16 -> bf16 intrinsics,
427 /// since the same numbers of VGPRs is used, we need to decide if to store the
428 /// result in the upper 16 bits of the VGPRs or in the lower part. To store the
429 /// result in the lower 16 bits, set subwordOffset to 1, otherwise result will
430 /// be stored it in the upper part
432  Location loc,
433  const TypeConverter *typeConverter,
434  Value output, int32_t subwordOffset,
435  bool clamp, SmallVector<Value, 4> &operands) {
436  Type inputType = output.getType();
437  auto vectorType = dyn_cast<VectorType>(inputType);
438  Type elemType = vectorType.getElementType();
439  if (elemType.isBF16())
440  output = rewriter.create<LLVM::BitcastOp>(
441  loc, vectorType.clone(rewriter.getI16Type()), output);
442  operands.push_back(output);
443  if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) {
444  operands.push_back(createI1Constant(rewriter, loc, subwordOffset));
445  } else if (elemType.isInteger(32)) {
446  operands.push_back(createI1Constant(rewriter, loc, clamp));
447  }
448 }
449 
450 /// Return the `rocdl` intrinsic corresponding to a MFMA operation `mfma`
451 /// if one exists. This includes checking to ensure the intrinsic is supported
452 /// on the architecture you are compiling for.
453 static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
454  Chipset chipset) {
455  uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
456  b = mfma.getBlocks();
457  Type sourceElem = mfma.getSourceA().getType();
458  if (auto sourceType = dyn_cast<VectorType>(sourceElem))
459  sourceElem = sourceType.getElementType();
460  Type destElem = mfma.getDestC().getType();
461  if (auto destType = dyn_cast<VectorType>(destElem))
462  destElem = destType.getElementType();
463 
464  if (sourceElem.isF32() && destElem.isF32()) {
465  if (mfma.getReducePrecision() && chipset >= kGfx940) {
466  if (m == 32 && n == 32 && k == 4 && b == 1)
467  return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
468  if (m == 16 && n == 16 && k == 8 && b == 1)
469  return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
470  }
471  if (m == 32 && n == 32 && k == 1 && b == 2)
472  return ROCDL::mfma_f32_32x32x1f32::getOperationName();
473  if (m == 16 && n == 16 && k == 1 && b == 4)
474  return ROCDL::mfma_f32_16x16x1f32::getOperationName();
475  if (m == 4 && n == 4 && k == 1 && b == 16)
476  return ROCDL::mfma_f32_4x4x1f32::getOperationName();
477  if (m == 32 && n == 32 && k == 2 && b == 1)
478  return ROCDL::mfma_f32_32x32x2f32::getOperationName();
479  if (m == 16 && n == 16 && k == 4 && b == 1)
480  return ROCDL::mfma_f32_16x16x4f32::getOperationName();
481  }
482 
483  if (sourceElem.isF16() && destElem.isF32()) {
484  if (m == 32 && n == 32 && k == 4 && b == 2)
485  return ROCDL::mfma_f32_32x32x4f16::getOperationName();
486  if (m == 16 && n == 16 && k == 4 && b == 4)
487  return ROCDL::mfma_f32_16x16x4f16::getOperationName();
488  if (m == 4 && n == 4 && k == 4 && b == 16)
489  return ROCDL::mfma_f32_4x4x4f16::getOperationName();
490  if (m == 32 && n == 32 && k == 8 && b == 1)
491  return ROCDL::mfma_f32_32x32x8f16::getOperationName();
492  if (m == 16 && n == 16 && k == 16 && b == 1)
493  return ROCDL::mfma_f32_16x16x16f16::getOperationName();
494  }
495 
496  if (sourceElem.isBF16() && destElem.isF32() && chipset >= kGfx90a) {
497  if (m == 32 && n == 32 && k == 4 && b == 2)
498  return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
499  if (m == 16 && n == 16 && k == 4 && b == 4)
500  return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
501  if (m == 4 && n == 4 && k == 4 && b == 16)
502  return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
503  if (m == 32 && n == 32 && k == 8 && b == 1)
504  return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
505  if (m == 16 && n == 16 && k == 16 && b == 1)
506  return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
507  }
508 
509  if (sourceElem.isBF16() && destElem.isF32()) {
510  if (m == 32 && n == 32 && k == 2 && b == 2)
511  return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
512  if (m == 16 && n == 16 && k == 2 && b == 4)
513  return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
514  if (m == 4 && n == 4 && k == 2 && b == 16)
515  return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
516  if (m == 32 && n == 32 && k == 4 && b == 1)
517  return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
518  if (m == 16 && n == 16 && k == 8 && b == 1)
519  return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
520  }
521 
522  if (isa<IntegerType>(sourceElem) && destElem.isInteger(32)) {
523  if (m == 32 && n == 32 && k == 4 && b == 2)
524  return ROCDL::mfma_i32_32x32x4i8::getOperationName();
525  if (m == 16 && n == 16 && k == 4 && b == 4)
526  return ROCDL::mfma_i32_16x16x4i8::getOperationName();
527  if (m == 4 && n == 4 && k == 4 && b == 16)
528  return ROCDL::mfma_i32_4x4x4i8::getOperationName();
529  if (m == 32 && n == 32 && k == 8 && b == 1)
530  return ROCDL::mfma_i32_32x32x8i8::getOperationName();
531  if (m == 16 && n == 16 && k == 16 && b == 1)
532  return ROCDL::mfma_i32_16x16x16i8::getOperationName();
533  if (m == 32 && n == 32 && k == 16 && b == 1 && chipset >= kGfx940)
534  return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
535  if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >= kGfx940)
536  return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
537  }
538 
539  if (sourceElem.isF64() && destElem.isF64() && chipset >= kGfx90a) {
540  if (m == 16 && n == 16 && k == 4 && b == 1)
541  return ROCDL::mfma_f64_16x16x4f64::getOperationName();
542  if (m == 4 && n == 4 && k == 4 && b == 4)
543  return ROCDL::mfma_f64_4x4x4f64::getOperationName();
544  }
545 
546  if (sourceElem.isFloat8E5M2FNUZ() && destElem.isF32() && chipset >= kGfx940) {
547  // Known to be correct because there are no scalar f8 instructions and
548  // because a length mismatch will have been caught by the verifier.
549  Type sourceBElem =
550  cast<VectorType>(mfma.getSourceB().getType()).getElementType();
551  if (m == 16 && n == 16 && k == 32 && b == 1) {
552  if (sourceBElem.isFloat8E5M2FNUZ())
553  return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
554  if (sourceBElem.isFloat8E4M3FNUZ())
555  return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
556  }
557  if (m == 32 && n == 32 && k == 16 && b == 1) {
558  if (sourceBElem.isFloat8E5M2FNUZ())
559  return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
560  if (sourceBElem.isFloat8E4M3FNUZ())
561  return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
562  }
563  }
564 
565  if (sourceElem.isFloat8E4M3FNUZ() && destElem.isF32() && chipset >= kGfx940) {
566  Type sourceBElem =
567  cast<VectorType>(mfma.getSourceB().getType()).getElementType();
568  if (m == 16 && n == 16 && k == 32 && b == 1) {
569  if (sourceBElem.isFloat8E5M2FNUZ())
570  return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
571  if (sourceBElem.isFloat8E4M3FNUZ())
572  return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
573  }
574  if (m == 32 && n == 32 && k == 16 && b == 1) {
575  if (sourceBElem.isFloat8E5M2FNUZ())
576  return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
577  if (sourceBElem.isFloat8E4M3FNUZ())
578  return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
579  }
580  }
581 
582  return std::nullopt;
583 }
584 
585 /// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
586 /// if one exists. This includes checking to ensure the intrinsic is supported
587 /// on the architecture you are compiling for.
588 static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
589  Chipset chipset) {
590  auto sourceVectorType = dyn_cast<VectorType>(wmma.getSourceA().getType());
591  auto destVectorType = dyn_cast<VectorType>(wmma.getDestC().getType());
592  auto elemSourceType = sourceVectorType.getElementType();
593  auto elemDestType = destVectorType.getElementType();
594 
595  if (elemSourceType.isF16() && elemDestType.isF32())
596  return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
597  if (elemSourceType.isBF16() && elemDestType.isF32())
598  return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
599  if (elemSourceType.isF16() && elemDestType.isF16())
600  return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
601  if (elemSourceType.isBF16() && elemDestType.isBF16())
602  return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
603  if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
604  return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
605  if (elemSourceType.isFloat8E4M3FN() && elemDestType.isF32())
606  return ROCDL::wmma_f32_16x16x16_fp8::getOperationName();
607  if (elemSourceType.isFloat8E5M2() && elemDestType.isF32())
608  return ROCDL::wmma_f32_16x16x16_bf8::getOperationName();
609  return std::nullopt;
610 }
611 
612 namespace {
613 struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
614  MFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
615  : ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {}
616 
617  Chipset chipset;
618 
619  LogicalResult
620  matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
621  ConversionPatternRewriter &rewriter) const override {
622  Location loc = op.getLoc();
623  Type outType = typeConverter->convertType(op.getDestD().getType());
624  Type intrinsicOutType = outType;
625  if (auto outVecType = dyn_cast<VectorType>(outType))
626  if (outVecType.getElementType().isBF16())
627  intrinsicOutType = outVecType.clone(rewriter.getI16Type());
628 
629  if (chipset.majorVersion != 9 || chipset < kGfx908)
630  return op->emitOpError("MFMA only supported on gfx908+");
631  uint32_t getBlgpField = static_cast<uint32_t>(op.getBlgp());
632  if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
633  if (chipset < kGfx940)
634  return op.emitOpError("negation unsupported on older than gfx940");
635  getBlgpField |=
636  op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
637  }
638  std::optional<StringRef> maybeIntrinsic = mfmaOpToIntrinsic(op, chipset);
639  if (!maybeIntrinsic.has_value())
640  return op.emitOpError("no intrinsic matching MFMA size on given chipset");
641  OperationState loweredOp(loc, *maybeIntrinsic);
642  loweredOp.addTypes(intrinsicOutType);
643  loweredOp.addOperands(
644  {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
645  convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
646  adaptor.getDestC(), createI32Constant(rewriter, loc, op.getCbsz()),
647  createI32Constant(rewriter, loc, op.getAbid()),
648  createI32Constant(rewriter, loc, getBlgpField)});
649  Value lowered = rewriter.create(loweredOp)->getResult(0);
650  if (outType != intrinsicOutType)
651  lowered = rewriter.create<LLVM::BitcastOp>(loc, outType, lowered);
652  rewriter.replaceOp(op, lowered);
653  return success();
654  }
655 };
656 
657 struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
658  WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
659  : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}
660 
661  Chipset chipset;
662 
663  LogicalResult
664  matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
665  ConversionPatternRewriter &rewriter) const override {
666  Location loc = op.getLoc();
667  auto outType =
668  typeConverter->convertType<VectorType>(op.getDestD().getType());
669  if (!outType)
670  return rewriter.notifyMatchFailure(op, "type conversion failed");
671 
672  if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
673  return op->emitOpError("WMMA only supported on gfx11 and gfx12");
674 
675  // The WMMA operations represent vectors of bf16s as vectors of i16s, so we
676  // need to bitcast bfloats to i16 and then bitcast them back.
677  VectorType rawOutType = outType;
678  if (outType.getElementType().isBF16())
679  rawOutType = outType.clone(rewriter.getI16Type());
680 
681  std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);
682 
683  if (!maybeIntrinsic.has_value())
684  return op.emitOpError("no intrinsic matching WMMA on the given chipset");
685 
686  OperationState loweredOp(loc, *maybeIntrinsic);
687  loweredOp.addTypes(rawOutType);
688 
689  SmallVector<Value, 4> operands;
690  wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(),
691  adaptor.getSourceA(), op.getSourceA(), operands);
692  wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(),
693  adaptor.getSourceB(), op.getSourceB(), operands);
694  wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(),
695  op.getSubwordOffset(), op.getClamp(), operands);
696 
697  loweredOp.addOperands(operands);
698  Operation *lowered = rewriter.create(loweredOp);
699 
700  Operation *maybeCastBack = lowered;
701  if (rawOutType != outType)
702  maybeCastBack =
703  rewriter.create<LLVM::BitcastOp>(loc, outType, lowered->getResult(0));
704  rewriter.replaceOp(op, maybeCastBack->getResults());
705 
706  return success();
707  }
708 };
709 
710 namespace {
711 struct ExtPackedFp8OpLowering final
712  : public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
713  ExtPackedFp8OpLowering(const LLVMTypeConverter &converter, Chipset chipset)
714  : ConvertOpToLLVMPattern<amdgpu::ExtPackedFp8Op>(converter),
715  chipset(chipset) {}
716  Chipset chipset;
717 
718  LogicalResult
719  matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
720  ConversionPatternRewriter &rewriter) const override;
721 };
722 
723 struct PackedTrunc2xFp8OpLowering final
724  : public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> {
725  PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter,
726  Chipset chipset)
727  : ConvertOpToLLVMPattern<amdgpu::PackedTrunc2xFp8Op>(converter),
728  chipset(chipset) {}
729  Chipset chipset;
730 
731  LogicalResult
732  matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
733  ConversionPatternRewriter &rewriter) const override;
734 };
735 
736 struct PackedStochRoundFp8OpLowering final
737  : public ConvertOpToLLVMPattern<PackedStochRoundFp8Op> {
738  PackedStochRoundFp8OpLowering(const LLVMTypeConverter &converter,
739  Chipset chipset)
740  : ConvertOpToLLVMPattern<amdgpu::PackedStochRoundFp8Op>(converter),
741  chipset(chipset) {}
742  Chipset chipset;
743 
744  LogicalResult
745  matchAndRewrite(PackedStochRoundFp8Op op,
746  PackedStochRoundFp8OpAdaptor adaptor,
747  ConversionPatternRewriter &rewriter) const override;
748 };
749 } // end namespace
750 
751 LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
752  ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
753  ConversionPatternRewriter &rewriter) const {
754  Location loc = op.getLoc();
755  if (chipset.majorVersion != 9 || chipset < kGfx940)
756  return rewriter.notifyMatchFailure(
757  loc, "Fp8 conversion instructions are not available on target "
758  "architecture and their emulation is not implemented");
759  Type v4i8 =
760  getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type()));
761  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
762  Type f32 = getTypeConverter()->convertType(op.getResult().getType());
763 
764  Value source = adaptor.getSource();
765  auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
766  Type sourceElemType = getElementTypeOrSelf(op.getSource());
767  // Extend to a v4i8
768  if (!sourceVecType || sourceVecType.getNumElements() < 4) {
769  Value longVec = rewriter.create<LLVM::UndefOp>(loc, v4i8);
770  if (!sourceVecType) {
771  longVec = rewriter.create<LLVM::InsertElementOp>(
772  loc, longVec, source, createI32Constant(rewriter, loc, 0));
773  } else {
774  for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
775  Value idx = createI32Constant(rewriter, loc, i);
776  Value elem = rewriter.create<LLVM::ExtractElementOp>(loc, source, idx);
777  longVec =
778  rewriter.create<LLVM::InsertElementOp>(loc, longVec, elem, idx);
779  }
780  }
781  source = longVec;
782  }
783  Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
784  Value wordSel = createI32Constant(rewriter, loc, op.getIndex());
785  if (sourceElemType.isFloat8E5M2FNUZ()) {
786  rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
787  wordSel);
788  } else if (sourceElemType.isFloat8E4M3FNUZ()) {
789  rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
790  wordSel);
791  }
792  return success();
793 }
794 
795 LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
796  PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
797  ConversionPatternRewriter &rewriter) const {
798  Location loc = op.getLoc();
799  if (chipset.majorVersion != 9 || chipset < kGfx940)
800  return rewriter.notifyMatchFailure(
801  loc, "Fp8 conversion instructions are not available on target "
802  "architecture and their emulation is not implemented");
803  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
804 
805  Type resultType = op.getResult().getType();
806  Type resultElemType = getElementTypeOrSelf(resultType);
807 
808  Value sourceA = adaptor.getSourceA();
809  Value sourceB = adaptor.getSourceB();
810  if (!sourceB)
811  sourceB = rewriter.create<LLVM::UndefOp>(loc, sourceA.getType());
812  Value existing = adaptor.getExisting();
813  if (existing)
814  existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
815  else
816  existing = rewriter.create<LLVM::UndefOp>(loc, i32);
817  Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex());
818 
819  Value result;
820  if (resultElemType.isFloat8E5M2FNUZ())
821  result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
822  existing, wordSel);
823  else if (resultElemType.isFloat8E4M3FNUZ())
824  result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
825  existing, wordSel);
826 
827  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
828  op, getTypeConverter()->convertType(resultType), result);
829  return success();
830 }
831 
832 LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
833  PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
834  ConversionPatternRewriter &rewriter) const {
835  Location loc = op.getLoc();
836  if (chipset.majorVersion != 9 || chipset < kGfx940)
837  return rewriter.notifyMatchFailure(
838  loc, "Fp8 conversion instructions are not available on target "
839  "architecture and their emulation is not implemented");
840  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
841 
842  Type resultType = op.getResult().getType();
843  Type resultElemType = getElementTypeOrSelf(resultType);
844 
845  Value source = adaptor.getSource();
846  Value stoch = adaptor.getStochiasticParam();
847  Value existing = adaptor.getExisting();
848  if (existing)
849  existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
850  else
851  existing = rewriter.create<LLVM::UndefOp>(loc, i32);
852  Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex());
853 
854  Value result;
855  if (resultElemType.isFloat8E5M2FNUZ())
856  result = rewriter.create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch,
857  existing, byteSel);
858  else if (resultElemType.isFloat8E4M3FNUZ())
859  result = rewriter.create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch,
860  existing, byteSel);
861 
862  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
863  op, getTypeConverter()->convertType(resultType), result);
864  return success();
865 }
866 
// Implement the AMDGPU_DPPLowering class that will convert the amdgpu.dpp
// operation into the corresponding ROCDL instructions.
struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
  AMDGPUDPPLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
  // Target chipset. Not consulted in this pattern's body; retained for
  // parity with the other lowerings here. NOTE(review): confirm no chipset
  // gating is needed for DPP.
  Chipset chipset;

  LogicalResult
  matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    // Convert the source operand to the corresponding LLVM type
    Location loc = DppOp.getLoc();
    Value src = adaptor.getSrc();
    Value old = adaptor.getOld();
    Type srcType = src.getType();
    Type oldType = old.getType();
    // llvmType is the 32- or 64-bit type the ROCDL dpp.update intrinsic
    // operates on; sub-32-bit operands are widened to i32 below. Note:
    // getIntOrFloatBitWidth asserts on non-int/float types, so this assumes
    // scalar operands — TODO confirm against the op's verifier.
    Type llvmType = nullptr;
    if (srcType.getIntOrFloatBitWidth() < 32) {
      llvmType = rewriter.getI32Type();
    } else if (isa<FloatType>(srcType)) {
      llvmType = (srcType.getIntOrFloatBitWidth() == 32)
                     ? rewriter.getF32Type()
                     : rewriter.getF64Type();
    } else if (isa<IntegerType>(srcType)) {
      llvmType = (srcType.getIntOrFloatBitWidth() == 32)
                     ? rewriter.getI32Type()
                     : rewriter.getI64Type();
    }
    // Integer type of the same width as the source; staging type for
    // bitcasting sub-32-bit floats before widening.
    auto llvmSrcIntType = typeConverter->convertType(
        rewriter.getIntegerType(srcType.getIntOrFloatBitWidth()));

    // If the source type is less of 32, use bitcast to convert it to i32.
    auto convertOperand = [&](Value operand, Type operandType) {
      if (operandType.getIntOrFloatBitWidth() <= 16) {
        // Narrow floats (e.g. f16) first become the same-width integer so
        // they can be inserted into an integer vector.
        if (llvm::isa<FloatType>(operandType)) {
          operand =
              rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
        }
        // Insert the narrow value into lane 0 of a 32-bit-wide vector of the
        // narrow int type, then bitcast the whole vector to the 32-bit
        // llvmType; the remaining lanes stay undef.
        auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
            32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
        Value undefVec = rewriter.create<LLVM::UndefOp>(loc, llvmVecType);
        operand = rewriter.create<LLVM::InsertElementOp>(
            loc, undefVec, operand, createI32Constant(rewriter, loc, 0));
        operand = rewriter.create<LLVM::BitcastOp>(loc, llvmType, operand);
      }
      return operand;
    };

    src = convertOperand(src, srcType);
    old = convertOperand(old, oldType);

    // This is taken from the following file llvm/lib/Target/AMDGPU/SIDefines.h
    enum DppCtrl : unsigned {
      ROW_SHL0 = 0x100,
      ROW_SHR0 = 0x110,
      ROW_ROR0 = 0x120,
      WAVE_SHL1 = 0x130,
      WAVE_ROL1 = 0x134,
      WAVE_SHR1 = 0x138,
      WAVE_ROR1 = 0x13C,
      ROW_MIRROR = 0x140,
      ROW_HALF_MIRROR = 0x141,
      BCAST15 = 0x142,
      BCAST31 = 0x143,
    };

    auto kind = DppOp.getKind();
    auto permArgument = DppOp.getPermArgument();
    // Encoded dpp_ctrl immediate. Note this local deliberately shadows the
    // enum name above; the enumerators remain reachable via DppCtrl::.
    uint32_t DppCtrl = 0;

    switch (kind) {

    case DPPPerm::quad_perm:
      // Pack the four 2-bit lane selectors into the low byte, selector i at
      // bit position 2*i.
      if (auto quadPermAttr = cast<ArrayAttr>(*permArgument)) {
        int32_t i = 0;
        for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
          uint32_t num = elem.getInt();
          DppCtrl |= num << (i * 2);
          i++;
        }
      }
      break;
    case DPPPerm::row_shl:
      // Shift amount is encoded as an offset from the ROW_SHL0 base.
      if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
        DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
      }
      break;
    case DPPPerm::row_shr:
      if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
        DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
      }
      break;
    case DPPPerm::row_ror:
      if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
        DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
      }
      break;
    // The wave-wide and mirror variants take no argument; use the fixed
    // encodings from SIDefines.h directly.
    case DPPPerm::wave_shl:
      DppCtrl = DppCtrl::WAVE_SHL1;
      break;
    case DPPPerm::wave_shr:
      DppCtrl = DppCtrl::WAVE_SHR1;
      break;
    case DPPPerm::wave_rol:
      DppCtrl = DppCtrl::WAVE_ROL1;
      break;
    case DPPPerm::wave_ror:
      DppCtrl = DppCtrl::WAVE_ROR1;
      break;
    case DPPPerm::row_mirror:
      DppCtrl = DppCtrl::ROW_MIRROR;
      break;
    case DPPPerm::row_half_mirror:
      DppCtrl = DppCtrl::ROW_HALF_MIRROR;
      break;
    case DPPPerm::row_bcast_15:
      DppCtrl = DppCtrl::BCAST15;
      break;
    case DPPPerm::row_bcast_31:
      DppCtrl = DppCtrl::BCAST31;
      break;
    }

    // Check for row_mask, bank_mask, bound_ctrl if they exist and create
    // constants
    auto rowMask = DppOp->getAttrOfType<IntegerAttr>("row_mask").getInt();
    auto bankMask = DppOp->getAttrOfType<IntegerAttr>("bank_mask").getInt();
    bool boundCtrl = DppOp->getAttrOfType<BoolAttr>("bound_ctrl").getValue();

    // create a ROCDL_DPPMovOp instruction with the appropriate attributes
    auto dppMovOp = rewriter.create<ROCDL::DPPUpdateOp>(
        loc, llvmType, old, src, DppCtrl, rowMask, bankMask, boundCtrl);

    Value result = dppMovOp.getRes();
    // Undo the widening performed above: truncate back to the original
    // width, and bitcast back to float if the source was a narrow float.
    if (srcType.getIntOrFloatBitWidth() < 32) {
      result = rewriter.create<LLVM::TruncOp>(loc, llvmSrcIntType, result);
      if (!llvm::isa<IntegerType>(srcType)) {
        result = rewriter.create<LLVM::BitcastOp>(loc, srcType, result);
      }
    }

    // We are replacing the AMDGPU_DPPOp instruction with the new
    // ROCDL_DPPMovOp instruction
    rewriter.replaceOp(DppOp, ValueRange(result));
    return success();
  }
};
1015 
1016 struct ConvertAMDGPUToROCDLPass
1017  : public impl::ConvertAMDGPUToROCDLBase<ConvertAMDGPUToROCDLPass> {
1018  ConvertAMDGPUToROCDLPass() = default;
1019 
1020  void runOnOperation() override {
1021  MLIRContext *ctx = &getContext();
1022  FailureOr<Chipset> maybeChipset = Chipset::parse(chipset);
1023  if (failed(maybeChipset)) {
1024  emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
1025  return signalPassFailure();
1026  }
1027 
1029  LLVMTypeConverter converter(ctx);
1030  populateAMDGPUToROCDLConversionPatterns(converter, patterns, *maybeChipset);
1031  LLVMConversionTarget target(getContext());
1032  target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
1033  target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
1034  target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
1035  if (failed(applyPartialConversion(getOperation(), target,
1036  std::move(patterns))))
1037  signalPassFailure();
1038  }
1039 };
1040 } // namespace
1041 
1043  const LLVMTypeConverter &converter, RewritePatternSet &patterns,
1044  Chipset chipset) {
1045  patterns
1046  .add<RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
1047  RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
1048  RawBufferOpLowering<RawBufferAtomicFaddOp,
1049  ROCDL::RawPtrBufferAtomicFaddOp>,
1050  RawBufferOpLowering<RawBufferAtomicFmaxOp,
1051  ROCDL::RawPtrBufferAtomicFmaxOp>,
1052  RawBufferOpLowering<RawBufferAtomicSmaxOp,
1053  ROCDL::RawPtrBufferAtomicSmaxOp>,
1054  RawBufferOpLowering<RawBufferAtomicUminOp,
1055  ROCDL::RawPtrBufferAtomicUminOp>,
1056  RawBufferOpLowering<RawBufferAtomicCmpswapOp,
1057  ROCDL::RawPtrBufferAtomicCmpSwap>,
1058  AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
1059  MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
1060  PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering>(converter,
1061  chipset);
1062 }
1063 
/// Factory for the AMDGPU-to-ROCDL conversion pass defined above; the
/// chipset is supplied via the pass's `chipset` option.
std::unique_ptr<Pass> mlir::createConvertAMDGPUToROCDLPass() {
  return std::make_unique<ConvertAMDGPUToROCDLPass>();
}
static std::optional< StringRef > wmmaOpToIntrinsic(WMMAOp wmma, Chipset chipset)
Return the rocdl intrinsic corresponding to a WMMA operation wmma if one exists.
static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc, bool value)
static std::optional< StringRef > mfmaOpToIntrinsic(MFMAOp mfma, Chipset chipset)
Return the rocdl intrinsic corresponding to a MFMA operation mfma if one exists.
static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, Value output, int32_t subwordOffset, bool clamp, SmallVector< Value, 4 > &operands)
Push the output operand.
static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput, Value mlirInput, SmallVector< Value, 4 > &operands)
Push an input operand.
static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter, Location loc, Value input)
Converts a MFMA vector operand from MLIR AMDGPU dialect convention to ROCDL and LLVM AMDGPU intrinsic...
static Value createI32Constant(ConversionPatternRewriter &rewriter, Location loc, int32_t value)
static MLIRContext * getContext(OpFoldResult val)
static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound)
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
IntegerType getI16Type()
Definition: Builders.cpp:105
FloatType getF32Type()
Definition: Builders.cpp:87
IntegerAttr getI16IntegerAttr(int16_t value)
Definition: Builders.cpp:257
IntegerType getI64Type()
Definition: Builders.cpp:109
IntegerType getI32Type()
Definition: Builders.cpp:107
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:111
MLIRContext * getContext() const
Definition: Builders.h:56
IntegerType getI1Type()
Definition: Builders.cpp:97
IntegerType getI8Type()
Definition: Builders.cpp:103
FloatType getF64Type()
Definition: Builders.cpp:89
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:143
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Definition: MemRefBuilder.h:33
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:529
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
result_range getResults()
Definition: Operation.h:415
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
Type conversion class.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isF64() const
Definition: Types.cpp:60
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition: Types.cpp:87
bool isF32() const
Definition: Types.cpp:59
bool isFloat8E4M3FNUZ() const
Definition: Types.cpp:46
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition: Types.cpp:99
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:66
bool isF16() const
Definition: Types.cpp:57
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:133
bool isBF16() const
Definition: Types.cpp:56
bool isFloat8E5M2FNUZ() const
Definition: Types.cpp:43
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Include the generated interface declarations.
std::unique_ptr< Pass > createConvertAMDGPUToROCDLPass()
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
void populateAMDGPUToROCDLConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, amdgpu::Chipset chipset)
Note: The ROCDL target does not support the LLVM bfloat type at this time and so this function will a...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Represents the amdgpu gfx chipset version, e.g., gfx90a, gfx942, gfx1103.
Definition: Chipset.h:22
unsigned majorVersion
Definition: Chipset.h:23
static FailureOr< Chipset > parse(StringRef name)
Parses the chipset version string and returns the chipset on success, and failure otherwise.
Definition: Chipset.cpp:14