//===- AMDGPUToROCDL.cpp - AMDGPU to ROCDL dialect conversion -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"

#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"

#include "../LLVMCommon/MemRefDescriptor.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_CONVERTAMDGPUTOROCDLPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::amdgpu;

// Define commonly used chipset versions for convenience.
constexpr Chipset kGfx908 = Chipset(9, 0, 8);
constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
constexpr Chipset kGfx942 = Chipset(9, 4, 2);
constexpr Chipset kGfx950 = Chipset(9, 5, 0);

/// Convert an unsigned number `val` to i32.
static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
                                  Location loc, Value val) {
  IntegerType i32 = rewriter.getI32Type();
  // The cast asserts that `val` is of integer type.
  auto valTy = cast<IntegerType>(val.getType());
  if (i32 == valTy)
    return val;
  return valTy.getWidth() > 32
             ? Value(rewriter.create<LLVM::TruncOp>(loc, i32, val))
             : Value(rewriter.create<LLVM::ZExtOp>(loc, i32, val));
}

static Value createI32Constant(ConversionPatternRewriter &rewriter,
                               Location loc, int32_t value) {
  Type i32 = rewriter.getI32Type();
  return rewriter.create<LLVM::ConstantOp>(loc, i32, value);
}

static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
                              bool value) {
  Type llvmI1 = rewriter.getI1Type();
  return rewriter.create<LLVM::ConstantOp>(loc, llvmI1, value);
}

/// Returns the linear index used to access an element in the memref.
static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
                               Location loc, MemRefDescriptor &memRefDescriptor,
                               ValueRange indices, ArrayRef<int64_t> strides) {
  IntegerType i32 = rewriter.getI32Type();
  Value index;
  for (auto [i, increment, stride] : llvm::enumerate(indices, strides)) {
    if (stride != 1) { // Skip if stride is 1.
      Value strideValue =
          ShapedType::isDynamic(stride)
              ? convertUnsignedToI32(rewriter, loc,
                                     memRefDescriptor.stride(rewriter, loc, i))
              : rewriter.create<LLVM::ConstantOp>(loc, i32, stride);
      increment = rewriter.create<LLVM::MulOp>(loc, increment, strideValue);
    }
    index =
        index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
  }
  return index ? index : createI32Constant(rewriter, loc, 0);
}
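
// Worked example (illustrative only): for a memref<4x8xf32> with the canonical
// strides [8, 1] and indices (i, j), the stride-1 multiply is skipped and the
// emitted computation is i * 8 + j, the usual row-major linearization done in
// i32 arithmetic.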

/// Compute the contents of the `num_records` field for a given memref
/// descriptor - that is, the number of bytes that is one element past the
/// greatest possible valid index into the memref.
static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc,
                           MemRefType memrefType,
                           MemRefDescriptor &memrefDescriptor,
                           ArrayRef<int64_t> strides,
                           uint32_t elementByteWidth) {
  if (memrefType.hasStaticShape() &&
      !llvm::any_of(strides, ShapedType::isDynamic)) {
    int64_t size = memrefType.getRank() == 0 ? 1 : 0;
    ArrayRef<int64_t> shape = memrefType.getShape();
    for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
      size = std::max(shape[i] * strides[i], size);
    size = size * elementByteWidth;
    assert(size < std::numeric_limits<uint32_t>::max() &&
           "the memref buffer is too large");
    return createI32Constant(rewriter, loc, static_cast<int32_t>(size));
  }
  Value maxIndex;
  for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
    Value size = memrefDescriptor.size(rewriter, loc, i);
    Value stride = memrefDescriptor.stride(rewriter, loc, i);
    Value maxThisDim = rewriter.create<LLVM::MulOp>(loc, size, stride);
    maxIndex = maxIndex
                   ? rewriter.create<LLVM::UMaxOp>(loc, maxIndex, maxThisDim)
                   : maxThisDim;
  }
  Value maxIndexI32 = convertUnsignedToI32(rewriter, loc, maxIndex);
  Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);
  return rewriter.create<LLVM::MulOp>(loc, maxIndexI32, byteWidthConst);
}
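
// Worked example (illustrative only): for a statically shaped memref<4x8xf32>
// with strides [8, 1], the static branch computes
//   max(4 * 8, 8 * 1) * 4 = 128 bytes,
// the first byte past the last addressable element, which is what the hardware
// expects in `num_records`.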

static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc,
                            Value basePointer, Value numRecords,
                            bool boundsCheck, amdgpu::Chipset chipset,
                            Value cacheSwizzleStride = nullptr,
                            unsigned addressSpace = 8) {
  // The stride value is generally 0. However, on MI-300 and onward, you can
  // enable a cache swizzling mode by setting bit 14 of the stride field
  // and setting that stride to a cache stride.
  Type i16 = rewriter.getI16Type();
  Value stride;
  if (chipset.majorVersion == 9 && chipset >= kGfx942 && cacheSwizzleStride) {
    Value cacheStrideZext =
        rewriter.create<LLVM::ZExtOp>(loc, i16, cacheSwizzleStride);
    Value swizzleBit = rewriter.create<LLVM::ConstantOp>(
        loc, i16, rewriter.getI16IntegerAttr(1 << 14));
    stride = rewriter.create<LLVM::OrOp>(loc, cacheStrideZext, swizzleBit,
                                         /*isDisjoint=*/true);
  } else {
    stride = rewriter.create<LLVM::ConstantOp>(loc, i16,
                                               rewriter.getI16IntegerAttr(0));
  }
  // Flag word:
  // bits 0-11: dst sel, ignored by these intrinsics
  // bits 12-14: number format (ignored, must be nonzero, 7=float)
  // bits 15-18: data format (ignored, must be nonzero, 4=32bit)
  // bit 19: In nested heap (0 here)
  // bit 20: Behavior on unmap (0 means "return 0 / ignore")
  // bits 21-22: Index stride for swizzles (N/A)
  // bit 23: Add thread ID (0)
  // bit 24: Reserved to 1 (RDNA) or 0 (CDNA)
  // bits 25-26: Reserved (0)
  // bit 27: Buffer is non-volatile (CDNA only)
  // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 =
  // none, 3 = either swizzles or testing against offset field) RDNA only
  // bits 30-31: Type (must be 0)
  uint32_t flags = (7 << 12) | (4 << 15);
  if (chipset.majorVersion >= 10) {
    flags |= (1 << 24);
    uint32_t oob = boundsCheck ? 3 : 2;
    flags |= (oob << 28);
  }
  Value flagsConst = createI32Constant(rewriter, loc, flags);
  Type rsrcType =
      LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
  Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
      loc, rsrcType, basePointer, stride, numRecords, flagsConst);
  return resource;
}
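
// Worked example (illustrative only): on a CDNA chip such as gfx942 the flag
// word is just (7 << 12) | (4 << 15) = 0x27000. On an RDNA chip
// (majorVersion >= 10) with bounds checking enabled, bit 24 and out-of-bounds
// mode 3 are also set, giving 0x27000 | (1 << 24) | (3 << 28) = 0x31027000.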

namespace {
struct FatRawBufferCastLowering
    : public ConvertOpToLLVMPattern<FatRawBufferCastOp> {
  FatRawBufferCastLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<FatRawBufferCastOp>(converter),
        chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value memRef = adaptor.getSource();
    Value unconvertedMemref = op.getSource();
    MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());
    MemRefDescriptor descriptor(memRef);

    DataLayout dataLayout = DataLayout::closest(op);
    int64_t elementByteWidth =
        dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;

    int64_t unusedOffset = 0;
    SmallVector<int64_t, 5> strideVals;
    if (failed(memrefType.getStridesAndOffset(strideVals, unusedOffset)))
      return op.emitOpError("Can't lower non-stride-offset memrefs");

    Value numRecords = adaptor.getValidBytes();
    if (!numRecords)
      numRecords = getNumRecords(rewriter, loc, memrefType, descriptor,
                                 strideVals, elementByteWidth);

    Value basePointer =
        adaptor.getResetOffset()
            ? descriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
                                   memrefType)
            : descriptor.alignedPtr(rewriter, loc);

    Value offset = adaptor.getResetOffset()
                       ? rewriter.create<LLVM::ConstantOp>(
                             loc, getIndexType(), rewriter.getIndexAttr(0))
                       : descriptor.offset(rewriter, loc);

    bool hasSizes = memrefType.getRank() > 0;
    // No need to unpack() and pack() all the individual sizes and strides,
    // so we'll just extract the arrays.
    Value sizes = hasSizes ? rewriter.create<LLVM::ExtractValueOp>(
                                 loc, descriptor, kSizePosInMemRefDescriptor)
                           : Value{};
    Value strides = hasSizes
                        ? rewriter.create<LLVM::ExtractValueOp>(
                              loc, descriptor, kStridePosInMemRefDescriptor)
                        : Value{};

    Value fatPtr = makeBufferRsrc(
        rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(),
        chipset, adaptor.getCacheSwizzleStride(), /*addressSpace=*/7);

    Value result = MemRefDescriptor::poison(
        rewriter, loc,
        getTypeConverter()->convertType(op.getResult().getType()));
    result = rewriter.create<LLVM::InsertValueOp>(
        loc, result, fatPtr, kAllocatedPtrPosInMemRefDescriptor);
    result = rewriter.create<LLVM::InsertValueOp>(
        loc, result, fatPtr, kAlignedPtrPosInMemRefDescriptor);
    result = rewriter.create<LLVM::InsertValueOp>(loc, result, offset,
                                                  kOffsetPosInMemRefDescriptor);
    if (hasSizes) {
      result = rewriter.create<LLVM::InsertValueOp>(loc, result, sizes,
                                                    kSizePosInMemRefDescriptor);
      result = rewriter.create<LLVM::InsertValueOp>(
          loc, result, strides, kStridePosInMemRefDescriptor);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// Define lowering patterns for raw buffer ops
template <typename GpuOp, typename Intrinsic>
struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
  RawBufferOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}

  Chipset chipset;
  static constexpr uint32_t maxVectorOpWidth = 128;

  LogicalResult
  matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = gpuOp.getLoc();
    Value memref = adaptor.getMemref();
    Value unconvertedMemref = gpuOp.getMemref();
    MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());

    if (chipset.majorVersion < 9)
      return gpuOp.emitOpError("raw buffer ops require GCN or higher");

    Value storeData = adaptor.getODSOperands(0)[0];
    if (storeData == memref) // no write component to this op
      storeData = Value();
    Type wantedDataType;
    if (storeData)
      wantedDataType = storeData.getType();
    else
      wantedDataType = gpuOp.getODSResults(0)[0].getType();

    Value atomicCmpData = Value();
    // Operand index 1 of a load is the indices; trying to read them can crash.
    if (storeData) {
      Value maybeCmpData = adaptor.getODSOperands(1)[0];
      if (maybeCmpData != memref)
        atomicCmpData = maybeCmpData;
    }

    Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);

    Type i32 = rewriter.getI32Type();

    // Get the type size in bytes.
    DataLayout dataLayout = DataLayout::closest(gpuOp);
    int64_t elementByteWidth =
        dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;
    Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);

    // If we want to load a vector<NxT> with total size <= 32
    // bits, use a scalar load and bitcast it. Similarly, if bitsize(T) < 32
    // and the total load size is >= 32, use a vector load of
    // (N * bitsize(T) / 32) x i32 and bitcast. Also, the CAS intrinsic
    // requires integer operands, so bitcast any floats to integers.
    Type llvmBufferValType = llvmWantedDataType;
    if (atomicCmpData) {
      if (auto floatType = dyn_cast<FloatType>(wantedDataType))
        llvmBufferValType = this->getTypeConverter()->convertType(
            rewriter.getIntegerType(floatType.getWidth()));
    }
    if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
      uint32_t vecLen = dataVector.getNumElements();
      uint32_t elemBits =
          dataLayout.getTypeSizeInBits(dataVector.getElementType());
      uint32_t totalBits = elemBits * vecLen;
      bool usePackedFp16 =
          isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
      if (totalBits > maxVectorOpWidth)
        return gpuOp.emitOpError(
            "Total width of loads or stores must be no more than " +
            Twine(maxVectorOpWidth) + " bits, but we call for " +
            Twine(totalBits) +
            " bits. This should've been caught in validation");
      if (!usePackedFp16 && elemBits < 32) {
        if (totalBits > 32) {
          if (totalBits % 32 != 0)
            return gpuOp.emitOpError("Load or store of more than 32-bits that "
                                     "doesn't fit into words. Can't happen\n");
          llvmBufferValType = this->typeConverter->convertType(
              VectorType::get(totalBits / 32, i32));
        } else {
          llvmBufferValType = this->typeConverter->convertType(
              rewriter.getIntegerType(totalBits));
        }
      }
    }
    if (auto vecType = dyn_cast<VectorType>(llvmBufferValType)) {
      // Buffer intrinsics don't support 1-element vectors; cast them to
      // scalars.
      if (vecType.getNumElements() == 1)
        llvmBufferValType = vecType.getElementType();
    }

    SmallVector<Value, 6> args;
    if (storeData) {
      if (llvmBufferValType != llvmWantedDataType) {
        Value castForStore =
            rewriter.create<LLVM::BitcastOp>(loc, llvmBufferValType, storeData);
        args.push_back(castForStore);
      } else {
        args.push_back(storeData);
      }
    }

    if (atomicCmpData) {
      if (llvmBufferValType != llvmWantedDataType) {
        Value castForCmp = rewriter.create<LLVM::BitcastOp>(
            loc, llvmBufferValType, atomicCmpData);
        args.push_back(castForCmp);
      } else {
        args.push_back(atomicCmpData);
      }
    }

    // Construct buffer descriptor from memref, attributes
    int64_t offset = 0;
    SmallVector<int64_t, 5> strides;
    if (failed(memrefType.getStridesAndOffset(strides, offset)))
      return gpuOp.emitOpError("Can't lower non-stride-offset memrefs");

    MemRefDescriptor memrefDescriptor(memref);

    Value ptr = memrefDescriptor.bufferPtr(
        rewriter, loc, *this->getTypeConverter(), memrefType);
    Value numRecords = getNumRecords(
        rewriter, loc, memrefType, memrefDescriptor, strides, elementByteWidth);
    Value resource = makeBufferRsrc(rewriter, loc, ptr, numRecords,
                                    adaptor.getBoundsCheck(), chipset);
    args.push_back(resource);

    // Indexing (voffset)
    Value voffset = getLinearIndexI32(rewriter, loc, memrefDescriptor,
                                      adaptor.getIndices(), strides);
    if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
        indexOffset && *indexOffset > 0) {
      Value extraOffsetConst = createI32Constant(rewriter, loc, *indexOffset);
      voffset =
          voffset ? rewriter.create<LLVM::AddOp>(loc, voffset, extraOffsetConst)
                  : extraOffsetConst;
    }
    voffset = rewriter.create<LLVM::MulOp>(loc, voffset, byteWidthConst);
    args.push_back(voffset);

    // SGPR offset.
    Value sgprOffset = adaptor.getSgprOffset();
    if (!sgprOffset)
      sgprOffset = createI32Constant(rewriter, loc, 0);
    sgprOffset = rewriter.create<LLVM::MulOp>(loc, sgprOffset, byteWidthConst);
    args.push_back(sgprOffset);

    // bit 0: GLC = 0 (atomics drop value, less coherency)
    // bits 1-2: SLC, DLC = 0 (similarly)
    // bit 3: swizzled (0 for raw)
    args.push_back(createI32Constant(rewriter, loc, 0));

    llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
                                           llvmBufferValType);
    Operation *lowered = rewriter.create<Intrinsic>(loc, resultTypes, args,
                                                    ArrayRef<NamedAttribute>());
    if (lowered->getNumResults() == 1) {
      Value replacement = lowered->getResult(0);
      if (llvmBufferValType != llvmWantedDataType) {
        replacement = rewriter.create<LLVM::BitcastOp>(loc, llvmWantedDataType,
                                                       replacement);
      }
      rewriter.replaceOp(gpuOp, replacement);
    } else {
      rewriter.eraseOp(gpuOp);
    }
    return success();
  }
};

struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
  LDSBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    bool requiresInlineAsm = chipset < kGfx90a || chipset.majorVersion == 11;

    if (requiresInlineAsm) {
      auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
                                                      LLVM::AsmDialect::AD_ATT);
      const char *asmStr =
          ";;;WARNING: BREAKS DEBUG WATCHES\ns_waitcnt lgkmcnt(0)\ns_barrier";
      const char *constraints = "";
      rewriter.replaceOpWithNewOp<LLVM::InlineAsmOp>(
          op,
          /*resultTypes=*/TypeRange(), /*operands=*/ValueRange(),
          /*asm_string=*/asmStr, constraints, /*has_side_effects=*/true,
          /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
          /*operand_attrs=*/ArrayAttr());
      return success();
    }
    if (chipset.majorVersion < 12) {
      constexpr int32_t ldsOnlyBitsGfx6789 = ~(0x1f << 8);
      constexpr int32_t ldsOnlyBitsGfx10 = ~(0x3f << 8);
      // Left in place in case someone disables the inline ASM path or future
      // chipsets use the same bit pattern.
      constexpr int32_t ldsOnlyBitsGfx11 = ~(0x3f << 4);

      int32_t ldsOnlyBits;
      if (chipset.majorVersion == 11)
        ldsOnlyBits = ldsOnlyBitsGfx11;
      else if (chipset.majorVersion == 10)
        ldsOnlyBits = ldsOnlyBitsGfx10;
      else if (chipset.majorVersion <= 9)
        ldsOnlyBits = ldsOnlyBitsGfx6789;
      else
        return op.emitOpError(
                   "don't know how to lower this for chipset major version")
               << chipset.majorVersion;

      Location loc = op->getLoc();
      rewriter.create<ROCDL::SWaitcntOp>(loc, ldsOnlyBits);
      rewriter.replaceOpWithNewOp<ROCDL::SBarrierOp>(op);
    } else {
      Location loc = op->getLoc();
      rewriter.create<ROCDL::WaitDscntOp>(loc, 0);
      rewriter.create<ROCDL::BarrierSignalOp>(loc, -1);
      rewriter.replaceOpWithNewOp<ROCDL::BarrierWaitOp>(op, -1);
    }

    return success();
  }
};

struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
  SchedBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<SchedBarrierOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<ROCDL::SchedBarrier>(op,
                                                     (uint32_t)op.getOpts());
    return success();
  }
};

} // namespace

/// Converts a MFMA vector operand from MLIR AMDGPU dialect convention to ROCDL
/// and LLVM AMDGPU intrinsics convention.
///
/// Specifically:
/// 1. If the element type is bfloat16, bitcast it to i16.
/// 2. If instead we have a more than 64-bit quantity, use a <N / 4 x i32>
///    instead, which is what the f8f6f4 intrinsics use.
/// 3. If `input` is a vector of N <= 8 bytes, bitcast it to a (N * 8)-bit
///    integer.
///
/// Note that the type of `input` has already been LLVM type converted:
/// therefore 8-bit and smaller floats are represented as their corresponding
/// `iN` integers.
static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
                                      Location loc, Value input) {
  Type inputType = input.getType();
  if (auto vectorType = dyn_cast<VectorType>(inputType)) {
    if (vectorType.getElementType().isBF16())
      return rewriter.create<LLVM::BitcastOp>(
          loc, vectorType.clone(rewriter.getI16Type()), input);
    if (vectorType.getElementType().isInteger(8) &&
        vectorType.getNumElements() <= 8)
      return rewriter.create<LLVM::BitcastOp>(
          loc, rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
    if (isa<IntegerType>(vectorType.getElementType()) &&
        vectorType.getElementTypeBitWidth() <= 8) {
      int64_t numWords = llvm::divideCeil(
          vectorType.getNumElements() * vectorType.getElementTypeBitWidth(),
          32);
      return rewriter.create<LLVM::BitcastOp>(
          loc, VectorType::get(numWords, rewriter.getI32Type()), input);
    }
  }
  return input;
}
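
// Worked example (illustrative only): a vector<4xbf16> operand becomes
// vector<4xi16>; a vector<8xi8> (e.g. eight 8-bit floats after LLVM type
// conversion) becomes a single i64; and a vector<32xi8> becomes vector<8xi32>,
// the form the f8f6f4 intrinsics expect.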

/// Push an input operand. If it is a float type, nothing to do. If it is
/// an integer type, then we need to also push its signedness (1 for signed, 0
/// for unsigned) and we need to pack the input 16xi8 vector into a 4xi32
/// vector (or the 8xi8 vector into a 2xi32 one for gfx12+).
/// We also need to convert bfloat inputs to i16 to account for the bfloat
/// intrinsics having been defined before the AMD backend supported bfloat. We
/// similarly need to pack 8-bit float types into integers as if they were i8
/// (which they are for the backend's purposes).
static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
                                 Location loc,
                                 const TypeConverter *typeConverter,
                                 bool isUnsigned, Value llvmInput,
                                 Value mlirInput,
                                 SmallVector<Value, 4> &operands) {
  Type inputType = llvmInput.getType();
  auto vectorType = dyn_cast<VectorType>(inputType);
  if (!vectorType) {
    operands.push_back(llvmInput);
    return;
  }
  Type elemType = vectorType.getElementType();

  if (elemType.isBF16())
    llvmInput = rewriter.create<LLVM::BitcastOp>(
        loc, vectorType.clone(rewriter.getI16Type()), llvmInput);
  if (elemType.getIntOrFloatBitWidth() > 8) {
    operands.push_back(llvmInput);
    return;
  }

  // We need to check the type of the input before conversion to properly test
  // for int8. This is because, in LLVM, the fp8 type is converted to int8, so
  // the fp8/int8 information is lost during the conversion process.
  auto mlirInputType = cast<VectorType>(mlirInput.getType());
  bool isInputInteger = mlirInputType.getElementType().isInteger();
  if (isInputInteger) {
    // If the element type is explicitly signed or unsigned, it overrides the
    // isUnsigned flag.
    bool localIsUnsigned = isUnsigned;
    if (elemType.isUnsignedInteger()) {
      localIsUnsigned = true;
    } else if (elemType.isSignedInteger()) {
      localIsUnsigned = false;
    }
    Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
    operands.push_back(sign);
  }

  int64_t numBits =
      vectorType.getNumElements() * elemType.getIntOrFloatBitWidth();
  Type i32 = rewriter.getI32Type();
  Type intrinsicInType = numBits <= 32
                             ? (Type)rewriter.getIntegerType(numBits)
                             : (Type)VectorType::get(numBits / 32, i32);
  auto llvmIntrinsicInType = typeConverter->convertType(intrinsicInType);
  Value castInput = rewriter.createOrFold<LLVM::BitcastOp>(
      loc, llvmIntrinsicInType, llvmInput);
  // The wave64-mode 16x16x16 intrinsics that take 4-bit integers only need
  // (256 / 64) * 4 = 16 bits of input (on gfx12+) but take i32 arguments.
  // Add in the zeros here.
  if (numBits < 32)
    castInput = rewriter.create<LLVM::ZExtOp>(loc, i32, castInput);
  operands.push_back(castInput);
}
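
// Worked example (illustrative only): a signed vector<16xi8> input pushes an
// i1 sign flag of 1 followed by the data bitcast to vector<4xi32>; a
// vector<4xi4> input (16 bits total) is bitcast to i16 and then zero-extended
// to the i32 argument the intrinsic expects.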

/// Push the output operand. In many cases this is just pushing the output in
/// the operand list. But when we have f16 -> f16 or bf16 -> bf16 intrinsics,
/// since the same number of VGPRs is used, we need to decide whether to store
/// the result in the upper or the lower 16 bits of the VGPRs. To store the
/// result in the lower 16 bits, set subwordOffset to 1; otherwise the result
/// will be stored in the upper part. The subwordOffset must not be set for
/// gfx12, as the instructions have been changed to return fewer registers
/// instead.
static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
                                  Location loc,
                                  const TypeConverter *typeConverter,
                                  Value output, int32_t subwordOffset,
                                  bool clamp, SmallVector<Value, 4> &operands) {
  Type inputType = output.getType();
  auto vectorType = dyn_cast<VectorType>(inputType);
  Type elemType = vectorType.getElementType();
  if (elemType.isBF16())
    output = rewriter.create<LLVM::BitcastOp>(
        loc, vectorType.clone(rewriter.getI16Type()), output);
  operands.push_back(output);
  if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) {
    operands.push_back(createI1Constant(rewriter, loc, subwordOffset));
  } else if (elemType.isInteger(32)) {
    operands.push_back(createI1Constant(rewriter, loc, clamp));
  }
}
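
// Worked example (illustrative only): for a bf16 -> bf16 WMMA the accumulator
// is bitcast to vector<Nxi16> and followed by an i1 encoding subwordOffset;
// for an i32 accumulator the trailing i1 instead encodes the clamp flag.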

/// Return true if `type` is the E5M2 variant of an 8-bit float that is
/// supported by the `_bf8` instructions on the given `chipset`.
static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type) {
  return (chipset == kGfx942 && isa<Float8E5M2FNUZType>(type)) ||
         (hasOcpFp8(chipset) && isa<Float8E5M2Type>(type));
}

/// Return true if `type` is the E4M3FN variant of an 8-bit float that is
/// supported by the `_fp8` instructions on the given `chipset`.
static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type) {
  return (chipset == kGfx942 && isa<Float8E4M3FNUZType>(type)) ||
         (hasOcpFp8(chipset) && isa<Float8E4M3FNType>(type));
}

/// Return the `rocdl` intrinsic corresponding to a MFMA operation `mfma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
                                                  Chipset chipset) {
  uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
           b = mfma.getBlocks();
  Type sourceElem = getElementTypeOrSelf(mfma.getSourceA().getType());
  Type destElem = getElementTypeOrSelf(mfma.getDestC().getType());

  if (sourceElem.isF32() && destElem.isF32()) {
    if (mfma.getReducePrecision() && chipset >= kGfx942) {
      if (m == 32 && n == 32 && k == 4 && b == 1)
        return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
      if (m == 16 && n == 16 && k == 8 && b == 1)
        return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
    }
    if (m == 32 && n == 32 && k == 1 && b == 2)
      return ROCDL::mfma_f32_32x32x1f32::getOperationName();
    if (m == 16 && n == 16 && k == 1 && b == 4)
      return ROCDL::mfma_f32_16x16x1f32::getOperationName();
    if (m == 4 && n == 4 && k == 1 && b == 16)
      return ROCDL::mfma_f32_4x4x1f32::getOperationName();
    if (m == 32 && n == 32 && k == 2 && b == 1)
      return ROCDL::mfma_f32_32x32x2f32::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f32_16x16x4f32::getOperationName();
  }

  if (sourceElem.isF16() && destElem.isF32()) {
    if (chipset >= kGfx950) {
      if (m == 32 && n == 32 && k == 16 && b == 1)
        return ROCDL::mfma_f32_32x32x16_f16::getOperationName();
      if (m == 16 && n == 16 && k == 32 && b == 1)
        return ROCDL::mfma_f32_16x16x32_f16::getOperationName();
    }
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_f32_32x32x4f16::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_f32_16x16x4f16::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_f32_4x4x4f16::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_f32_32x32x8f16::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_f32_16x16x16f16::getOperationName();
  }

  if (sourceElem.isBF16() && destElem.isF32()) {
    if (chipset >= kGfx950) {
      if (m == 32 && n == 32 && k == 16 && b == 1)
        return ROCDL::mfma_f32_32x32x16_bf16::getOperationName();
      if (m == 16 && n == 16 && k == 32 && b == 1)
        return ROCDL::mfma_f32_16x16x32_bf16::getOperationName();
    }
    if (chipset >= kGfx90a) {
      if (m == 32 && n == 32 && k == 4 && b == 2)
        return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
      if (m == 16 && n == 16 && k == 4 && b == 4)
        return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
      if (m == 4 && n == 4 && k == 4 && b == 16)
        return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
      if (m == 32 && n == 32 && k == 8 && b == 1)
        return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
      if (m == 16 && n == 16 && k == 16 && b == 1)
        return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
    }
    if (m == 32 && n == 32 && k == 2 && b == 2)
      return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
    if (m == 16 && n == 16 && k == 2 && b == 4)
      return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
    if (m == 4 && n == 4 && k == 2 && b == 16)
      return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
    if (m == 32 && n == 32 && k == 4 && b == 1)
      return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
    if (m == 16 && n == 16 && k == 8 && b == 1)
      return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
  }

  if (sourceElem.isInteger(8) && destElem.isInteger(32)) {
    if (chipset >= kGfx950) {
      if (m == 32 && n == 32 && k == 32 && b == 1)
        return ROCDL::mfma_i32_32x32x32_i8::getOperationName();
      if (m == 16 && n == 16 && k == 64 && b == 1)
        return ROCDL::mfma_i32_16x16x64_i8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_i32_32x32x4i8::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_i32_16x16x4i8::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_i32_4x4x4i8::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_i32_32x32x8i8::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_i32_16x16x16i8::getOperationName();
    if (m == 32 && n == 32 && k == 16 && b == 1 && chipset >= kGfx942)
      return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
    if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >= kGfx942)
      return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
  }

  if (sourceElem.isF64() && destElem.isF64() && chipset >= kGfx90a) {
    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f64_16x16x4f64::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 4)
      return ROCDL::mfma_f64_4x4x4f64::getOperationName();
  }

  if (destElem.isF32() && typeIsExpectedBf8ForChipset(chipset, sourceElem)) {
    // Known to be correct because there are no scalar f8 instructions and
    // because a length mismatch will have been caught by the verifier.
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
    }
  }

  if (destElem.isF32() && typeIsExpectedFp8ForChipset(chipset, sourceElem)) {
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
    }
  }

  return std::nullopt;
}
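
// Worked example (illustrative only): an amdgpu.mfma with m = n = 16, k = 16,
// blocks = 1, f16 sources, and an f32 accumulator resolves to
// rocdl.mfma.f32.16x16x16f16; the same shape with i8 sources and an i32
// accumulator resolves to rocdl.mfma.i32.16x16x16i8.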

static std::optional<uint32_t> mfmaTypeSelectCode(Type mlirElemType) {
  return llvm::TypeSwitch<Type, std::optional<uint32_t>>(mlirElemType)
      .Case([](Float8E4M3FNType) { return 0u; })
      .Case([](Float8E5M2Type) { return 1u; })
      .Case([](Float6E2M3FNType) { return 2u; })
      .Case([](Float6E3M2FNType) { return 3u; })
      .Case([](Float4E2M1FNType) { return 4u; })
      .Default([](Type) { return std::nullopt; });
}

/// If there is a scaled MFMA instruction for the input element types `aType`
/// and `bType`, output type `destType`, problem size M, N, K, and B (number of
/// blocks) on the given `chipset`, return a tuple consisting of the
/// OperationName of the intrinsic and the type codes that need to be passed to
/// that intrinsic. Note that this is also used to implement some un-scaled
/// MFMAs, since the compiler represents the ordinary instruction as a "scaled"
/// MFMA with a scale of 0.
static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m,
                        uint32_t n, uint32_t k, uint32_t b, Chipset chipset) {
  aType = getElementTypeOrSelf(aType);
  bType = getElementTypeOrSelf(bType);
  destType = getElementTypeOrSelf(destType);

  if (chipset < kGfx950)
    return std::nullopt;
  if (!isa<Float32Type>(destType))
    return std::nullopt;

  std::optional<uint32_t> aTypeCode = mfmaTypeSelectCode(aType);
  std::optional<uint32_t> bTypeCode = mfmaTypeSelectCode(bType);
  if (!aTypeCode || !bTypeCode)
    return std::nullopt;

  if (m == 32 && n == 32 && k == 64 && b == 1)
    return std::tuple{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(),
                      *aTypeCode, *bTypeCode};
  if (m == 16 && n == 16 && k == 128 && b == 1)
    return std::tuple{
        ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), *aTypeCode,
        *bTypeCode};

  return std::nullopt;
}
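
// Worked example (illustrative only): on gfx950, aType = f8E4M3FN (code 0),
// bType = f8E5M2 (code 1), destType = f32, and (m, n, k, b) = (32, 32, 64, 1)
// yields {rocdl.mfma.scale.f32.32x32x64.f8f6f4, 0, 1}: the intrinsic name plus
// the two type codes the caller passes in the cbsz/blgp operand slots.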

static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
mfmaOpToScaledIntrinsic(MFMAOp mfma, Chipset chipset) {
  return mfmaOpToScaledIntrinsic(
      mfma.getSourceA().getType(), mfma.getSourceB().getType(),
      mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(),
      mfma.getBlocks(), chipset);
}

/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
                                                  Chipset chipset) {
  auto sourceVectorType = dyn_cast<VectorType>(wmma.getSourceA().getType());
  auto sourceBVectorType = dyn_cast<VectorType>(wmma.getSourceB().getType());
  auto destVectorType = dyn_cast<VectorType>(wmma.getDestC().getType());
  auto elemSourceType = sourceVectorType.getElementType();
  auto elemBSourceType = sourceBVectorType.getElementType();
  auto elemDestType = destVectorType.getElementType();

  if (elemSourceType.isF16() && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
  if (elemSourceType.isBF16() && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
  if (elemSourceType.isF16() && elemDestType.isF16())
    return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
  if (elemSourceType.isBF16() && elemDestType.isBF16())
    return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
  if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
    return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
  if (chipset.majorVersion == 11) {
    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
      return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
  }
  if (chipset.majorVersion >= 12) {
    if (isa<Float8E4M3FNType>(elemSourceType) &&
        isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
    if (isa<Float8E4M3FNType>(elemSourceType) &&
        isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
    if (isa<Float8E5M2Type>(elemSourceType) &&
        isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
    if (isa<Float8E5M2Type>(elemSourceType) &&
        isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) {
      bool isWave64 = destVectorType.getNumElements() == 4;
      // This is the ambiguous case: 8 inputs to the wave64 version mean that
      // we want the 16x16x32 version, but for wave32 they mean the short form.
      bool has8Inputs = sourceVectorType.getNumElements() == 8;
      if ((isWave64 && has8Inputs) || (!isWave64 && !has8Inputs))
        return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
      return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
    }
  }
  return std::nullopt;
}
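
// Worked example (illustrative only): for i4 sources and an i32 destination on
// gfx12, a wave64 kernel (4-element accumulator) with 8 inputs per thread maps
// to rocdl.wmma.i32.16x16x32.iu4, while 8 inputs under wave32 select the short
// rocdl.wmma.i32.16x16x16.iu4 form, as the comment above explains.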

namespace {
struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
  MFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type outType = typeConverter->convertType(op.getDestD().getType());
    Type intrinsicOutType = outType;
    if (auto outVecType = dyn_cast<VectorType>(outType))
      if (outVecType.getElementType().isBF16())
        intrinsicOutType = outVecType.clone(rewriter.getI16Type());

    if (chipset.majorVersion != 9 || chipset < kGfx908)
      return op->emitOpError("MFMA only supported on gfx908+");
    uint32_t getBlgpField = static_cast<uint32_t>(op.getBlgp());
    if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
      if (chipset < kGfx942)
        return op.emitOpError("negation unsupported on older than gfx942");
      getBlgpField |=
          op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
    }
    std::optional<StringRef> maybeIntrinsic = mfmaOpToIntrinsic(op, chipset);
    std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
        maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
    if (!maybeIntrinsic.has_value() && !maybeScaledIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching MFMA size on given chipset");

    bool isScaled =
        !maybeIntrinsic.has_value() && maybeScaledIntrinsic.has_value();
    if (isScaled &&
        (adaptor.getAbid() > 0 || getBlgpField > 0 || op.getCbsz() > 0)) {
      return op.emitOpError(
          "non-default abid, blgp, and cbsz aren't supported on MFMAs that can "
          "be scaled as those fields are used for type information");
    }

    StringRef intrinsicName =
        isScaled ? std::get<0>(*maybeScaledIntrinsic) : *maybeIntrinsic;
    OperationState loweredOp(loc, intrinsicName);
    loweredOp.addTypes(intrinsicOutType);
    loweredOp.addOperands(
        {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
         convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
         adaptor.getDestC()});
    if (isScaled) {
      Value zero = createI32Constant(rewriter, loc, 0);
      auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
      loweredOp.addOperands({createI32Constant(rewriter, loc, aTypeCode),
                             createI32Constant(rewriter, loc, bTypeCode),
                             /*scale A byte=*/zero, /*scale A=*/zero,
                             /*scale B byte=*/zero, /*scale B=*/zero});
    } else {
      loweredOp.addOperands({createI32Constant(rewriter, loc, op.getCbsz()),
                             createI32Constant(rewriter, loc, op.getAbid()),
                             createI32Constant(rewriter, loc, getBlgpField)});
    }
    Value lowered = rewriter.create(loweredOp)->getResult(0);
    if (outType != intrinsicOutType)
      lowered = rewriter.create<LLVM::BitcastOp>(loc, outType, lowered);
    rewriter.replaceOp(op, lowered);
    return success();
  }
};

struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
  WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto outType =
        typeConverter->convertType<VectorType>(op.getDestD().getType());
    if (!outType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
      return op->emitOpError("WMMA only supported on gfx11 and gfx12");

    // The WMMA operations represent vectors of bf16s as vectors of i16s, so we
    // need to bitcast bfloats to i16 and then bitcast them back.
    VectorType rawOutType = outType;
    if (outType.getElementType().isBF16())
      rawOutType = outType.clone(rewriter.getI16Type());

    std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);

    if (!maybeIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching WMMA on the given chipset");

    if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0)
      return op.emitOpError("subwordOffset not supported on gfx12+");

    OperationState loweredOp(loc, *maybeIntrinsic);
    loweredOp.addTypes(rawOutType);

    SmallVector<Value, 4> operands;
    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(),
                         adaptor.getSourceA(), op.getSourceA(), operands);
    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(),
                         adaptor.getSourceB(), op.getSourceB(), operands);
    wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(),
                          op.getSubwordOffset(), op.getClamp(), operands);

    loweredOp.addOperands(operands);
    Operation *lowered = rewriter.create(loweredOp);

    Operation *maybeCastBack = lowered;
    if (rawOutType != outType)
      maybeCastBack =
          rewriter.create<LLVM::BitcastOp>(loc, outType, lowered->getResult(0));
    rewriter.replaceOp(op, maybeCastBack->getResults());

    return success();
  }
};

struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
  GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset < kGfx942)
      return op.emitOpError("chipset not supported");

    Location loc = op.getLoc();

    auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
    auto dstMemRefType = cast<MemRefType>(op.getDst().getType());

    // TODO: instead of only transferring one element per thread, we could
    // augment it to transfer multiple elements per thread by issuing multiple
    // `global_load_lds` instructions.
    Type transferType = op.getTransferType();
    size_t loadWidth = [&]() -> size_t {
      if (auto transferVectorType = dyn_cast<VectorType>(transferType)) {
        return transferVectorType.getNumElements() *
               (transferVectorType.getElementTypeBitWidth() / 8);
      } else {
        return transferType.getIntOrFloatBitWidth() / 8;
      }
    }();

    // Currently only 1, 2, and 4 byte loads are supported.
    if (loadWidth != 1 && loadWidth != 2 && loadWidth != 4)
      return op.emitOpError("unsupported element size for chipset");

    Value srcPtr = getStridedElementPtr(loc, srcMemRefType, adaptor.getSrc(),
                                        adaptor.getSrcIndices(), rewriter);
    Value dstPtr = getStridedElementPtr(loc, dstMemRefType, adaptor.getDst(),
                                        adaptor.getDstIndices(), rewriter);

    rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadLDSOp>(
        op, srcPtr, dstPtr, createI32Constant(rewriter, loc, loadWidth),
        createI32Constant(rewriter, loc, 0),
        createI32Constant(rewriter, loc, 0), ArrayAttr{}, ArrayAttr{},
        ArrayAttr{});

    return success();
  }
};

namespace {
struct ExtPackedFp8OpLowering final
    : public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
  ExtPackedFp8OpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::ExtPackedFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

struct PackedTrunc2xFp8OpLowering final
    : public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> {
  PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter,
                             Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::PackedTrunc2xFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

struct PackedStochRoundFp8OpLowering final
    : public ConvertOpToLLVMPattern<PackedStochRoundFp8Op> {
  PackedStochRoundFp8OpLowering(const LLVMTypeConverter &converter,
                                Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::PackedStochRoundFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PackedStochRoundFp8Op op,
                  PackedStochRoundFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // end namespace

LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
    ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type v4i8 =
      getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type()));
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
  Type f32 = getTypeConverter()->convertType(op.getResult().getType());

  Value source = adaptor.getSource();
  auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
  auto resultVecType = dyn_cast<VectorType>(op.getResult().getType());
  Type sourceElemType = getElementTypeOrSelf(op.getSource());
  // Extend to a v4i8
  if (!sourceVecType || sourceVecType.getNumElements() < 4) {
    Value longVec = rewriter.create<LLVM::UndefOp>(loc, v4i8);
    if (!sourceVecType) {
      longVec = rewriter.create<LLVM::InsertElementOp>(
          loc, longVec, source, createI32Constant(rewriter, loc, 0));
    } else {
      for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
        Value idx = createI32Constant(rewriter, loc, i);
        Value elem = rewriter.create<LLVM::ExtractElementOp>(loc, source, idx);
        longVec =
            rewriter.create<LLVM::InsertElementOp>(loc, longVec, elem, idx);
      }
    }
    source = longVec;
  }
  Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
  if (resultVecType) {
    Value wordSel = createI1Constant(rewriter, loc, op.getIndex());
    if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
      rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source,
                                                        wordSel);
    } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
      rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Fp8Op>(op, f32, i32Source,
                                                        wordSel);
    }
  } else {
    Value byteSel = createI32Constant(rewriter, loc, op.getIndex());
    if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
      rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
                                                      byteSel);
    } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
      rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
                                                      byteSel);
    }
  }
  return success();
}

LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
    PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Type resultType = op.getResult().getType();
  Type resultElemType = getElementTypeOrSelf(resultType);

  Value sourceA = adaptor.getSourceA();
  Value sourceB = adaptor.getSourceB();
  if (!sourceB)
    sourceB = rewriter.create<LLVM::UndefOp>(loc, sourceA.getType());
  Value existing = adaptor.getExisting();
  if (existing)
    existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
  else
    existing = rewriter.create<LLVM::UndefOp>(loc, i32);
  Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex());

  Value result;
  if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
    result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
                                                   existing, wordSel);
  else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
    result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
                                                   existing, wordSel);

  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
  return success();
}

LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
    PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Type resultType = op.getResult().getType();
  Type resultElemType = getElementTypeOrSelf(resultType);

  Value source = adaptor.getSource();
  Value stoch = adaptor.getStochiasticParam();
  Value existing = adaptor.getExisting();
  if (existing)
    existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
  else
    existing = rewriter.create<LLVM::UndefOp>(loc, i32);
  Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex());

  Value result;
  if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
    result = rewriter.create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch,
                                                   existing, byteSel);
  else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
    result = rewriter.create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch,
                                                   existing, byteSel);

  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
  return success();
}
1230 
1231 // Implement the AMDGPU_DPPLowering class that will convert the amdgpu.dpp
1232 // operation into the corresponding ROCDL instructions.
1233 struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
1234  AMDGPUDPPLowering(const LLVMTypeConverter &converter, Chipset chipset)
1235  : ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
1236  Chipset chipset;
1237 
1238  LogicalResult
1239  matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
1240  ConversionPatternRewriter &rewriter) const override {
1241 
1242  // Convert the source operand to the corresponding LLVM type
1243  Location loc = DppOp.getLoc();
1244  Value src = adaptor.getSrc();
1245  Value old = adaptor.getOld();
1246  Type srcType = src.getType();
1247  Type oldType = old.getType();
1248  Type llvmType = nullptr;
1249  if (srcType.getIntOrFloatBitWidth() < 32) {
1250  llvmType = rewriter.getI32Type();
1251  } else if (isa<FloatType>(srcType)) {
1252  llvmType = (srcType.getIntOrFloatBitWidth() == 32)
1253  ? rewriter.getF32Type()
1254  : rewriter.getF64Type();
1255  } else if (isa<IntegerType>(srcType)) {
1256  llvmType = (srcType.getIntOrFloatBitWidth() == 32)
1257  ? rewriter.getI32Type()
1258  : rewriter.getI64Type();
1259  }
1260  auto llvmSrcIntType = typeConverter->convertType(
1261  rewriter.getIntegerType(srcType.getIntOrFloatBitWidth()));
1262 
1263  // If the source type is less of 32, use bitcast to convert it to i32.
1264  auto convertOperand = [&](Value operand, Type operandType) {
1265  if (operandType.getIntOrFloatBitWidth() <= 16) {
1266  if (llvm::isa<FloatType>(operandType)) {
1267  operand =
1268  rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
1269  }
1270  auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
1271  32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
1272  Value undefVec = rewriter.create<LLVM::UndefOp>(loc, llvmVecType);
1273  operand = rewriter.create<LLVM::InsertElementOp>(
1274  loc, undefVec, operand, createI32Constant(rewriter, loc, 0));
1275  operand = rewriter.create<LLVM::BitcastOp>(loc, llvmType, operand);
1276  }
1277  return operand;
1278  };
1279 
1280  src = convertOperand(src, srcType);
1281  old = convertOperand(old, oldType);
1282 
1283  // This is taken from the following file llvm/lib/Target/AMDGPU/SIDefines.h
1284  enum DppCtrl : unsigned {
1285  ROW_SHL0 = 0x100,
1286  ROW_SHR0 = 0x110,
1287  ROW_ROR0 = 0x120,
1288  WAVE_SHL1 = 0x130,
1289  WAVE_ROL1 = 0x134,
1290  WAVE_SHR1 = 0x138,
1291  WAVE_ROR1 = 0x13C,
1292  ROW_MIRROR = 0x140,
1293  ROW_HALF_MIRROR = 0x141,
1294  BCAST15 = 0x142,
1295  BCAST31 = 0x143,
1296  };
1297 
1298  auto kind = DppOp.getKind();
1299  auto permArgument = DppOp.getPermArgument();
1300  uint32_t DppCtrl = 0;
1301 
1302  switch (kind) {
1303 
1304  case DPPPerm::quad_perm:
1305  if (auto quadPermAttr = cast<ArrayAttr>(*permArgument)) {
1306  int32_t i = 0;
1307  for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
1308  uint32_t num = elem.getInt();
1309  DppCtrl |= num << (i * 2);
1310  i++;
1311  }
1312  }
1313  break;
1314  case DPPPerm::row_shl:
1315  if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
1316  DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
1317  }
1318  break;
1319  case DPPPerm::row_shr:
1320  if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
1321  DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
1322  }
1323  break;
1324  case DPPPerm::row_ror:
1325  if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
1326  DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
1327  }
1328  break;
1329  case DPPPerm::wave_shl:
1330  DppCtrl = DppCtrl::WAVE_SHL1;
1331  break;
1332  case DPPPerm::wave_shr:
1333  DppCtrl = DppCtrl::WAVE_SHR1;
1334  break;
1335  case DPPPerm::wave_rol:
1336  DppCtrl = DppCtrl::WAVE_ROL1;
1337  break;
1338  case DPPPerm::wave_ror:
1339  DppCtrl = DppCtrl::WAVE_ROR1;
1340  break;
1341  case DPPPerm::row_mirror:
1342  DppCtrl = DppCtrl::ROW_MIRROR;
1343  break;
1344  case DPPPerm::row_half_mirror:
1345  DppCtrl = DppCtrl::ROW_HALF_MIRROR;
1346  break;
1347  case DPPPerm::row_bcast_15:
1348  DppCtrl = DppCtrl::BCAST15;
1349  break;
1350  case DPPPerm::row_bcast_31:
1351  DppCtrl = DppCtrl::BCAST31;
1352  break;
1353  }
1354 
1355  // Check for row_mask, bank_mask, bound_ctrl if they exist and create
1356  // constants
1357  auto rowMask = DppOp->getAttrOfType<IntegerAttr>("row_mask").getInt();
1358  auto bankMask = DppOp->getAttrOfType<IntegerAttr>("bank_mask").getInt();
1359  bool boundCtrl = DppOp->getAttrOfType<BoolAttr>("bound_ctrl").getValue();
1360 
1361  // create a ROCDL_DPPMovOp instruction with the appropriate attributes
1362  auto dppMovOp = rewriter.create<ROCDL::DPPUpdateOp>(
1363  loc, llvmType, old, src, DppCtrl, rowMask, bankMask, boundCtrl);
1364 
1365  Value result = dppMovOp.getRes();
1366  if (srcType.getIntOrFloatBitWidth() < 32) {
1367  result = rewriter.create<LLVM::TruncOp>(loc, llvmSrcIntType, result);
1368  if (!llvm::isa<IntegerType>(srcType)) {
1369  result = rewriter.create<LLVM::BitcastOp>(loc, srcType, result);
1370  }
1371  }
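 // Sketch of this narrowing for an f16 source (the inverse of the widening in
 // convertOperand above):
 //   %t = llvm.trunc %res : i32 to i16
 //   %r = llvm.bitcast %t : i16 to f16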
1372 
1373  // Replace the amdgpu.dpp op with the result of the new ROCDL DPP op.
1375  rewriter.replaceOp(DppOp, ValueRange(result));
1376  return success();
1377  }
1378 };
1379 
1380 struct ConvertAMDGPUToROCDLPass
1381  : public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
1382  using Base::Base;
1383 
1384  void runOnOperation() override {
1385  MLIRContext *ctx = &getContext();
1386  FailureOr<Chipset> maybeChipset = Chipset::parse(chipset);
1387  if (failed(maybeChipset)) {
1388  emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
1389  return signalPassFailure();
1390  }
1391 
1392  RewritePatternSet patterns(ctx);
1393  LLVMTypeConverter converter(ctx);
1394  populateAMDGPUToROCDLConversionPatterns(converter, patterns, *maybeChipset);
1395  LLVMConversionTarget target(getContext());
1396  target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
1397  target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
1398  target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
1399  if (failed(applyPartialConversion(getOperation(), target,
1400  std::move(patterns))))
1401  signalPassFailure();
1402  }
1403 };
1404 } // namespace
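 // Example invocation (a sketch; assumes the standard mlir-opt driver and that
 // the tablegen'd registration exposes the pass under this flag with the
 // `chipset` option used in runOnOperation above):
 //   mlir-opt --convert-amdgpu-to-rocdl='chipset=gfx942' input.mlir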
1405 
1406 void mlir::populateAMDGPUMemorySpaceAttributeConversions(
1407  TypeConverter &typeConverter) {
1408  typeConverter.addTypeAttributeConversion(
1409  [](BaseMemRefType type, amdgpu::AddressSpaceAttr as)
1410  -> TypeConverter::AttributeConversionResult {
1411  MLIRContext *ctx = as.getContext();
1412  Type i64 = IntegerType::get(ctx, 64);
1413  switch (as.getValue()) {
1414  case amdgpu::AddressSpace::FatRawBuffer:
1415  return IntegerAttr::get(i64, 7);
1416  case amdgpu::AddressSpace::BufferRsrc:
1417  return IntegerAttr::get(i64, 8);
1418  case amdgpu::AddressSpace::FatStructuredBuffer:
1419  return IntegerAttr::get(i64, 9);
1420  }
1421  return TypeConverter::AttributeConversionResult::abort();
1422  });
1423 }
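 // Illustrative effect of the conversion above (a sketch, assuming the
 // standard memref-to-LLVM lowering consumes the produced address-space
 // attribute): a memref such as
 //   memref<16xf32, #amdgpu.address_space<fat_raw_buffer>>
 // ends up with pointers in LLVM address space 7 (!llvm.ptr<7>), the AMDGPU
 // buffer-fat-pointer space, while buffer_rsrc maps to space 8 and
 // fat_structured_buffer to space 9, per the switch above.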
1424 
1425 void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
1426  RewritePatternSet &patterns,
1427  Chipset chipset) {
1428  populateAMDGPUMemorySpaceAttributeConversions(converter);
1429  patterns
1430  .add<FatRawBufferCastLowering,
1431  RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
1432  RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
1433  RawBufferOpLowering<RawBufferAtomicFaddOp,
1434  ROCDL::RawPtrBufferAtomicFaddOp>,
1435  RawBufferOpLowering<RawBufferAtomicFmaxOp,
1436  ROCDL::RawPtrBufferAtomicFmaxOp>,
1437  RawBufferOpLowering<RawBufferAtomicSmaxOp,
1438  ROCDL::RawPtrBufferAtomicSmaxOp>,
1439  RawBufferOpLowering<RawBufferAtomicUminOp,
1440  ROCDL::RawPtrBufferAtomicUminOp>,
1441  RawBufferOpLowering<RawBufferAtomicCmpswapOp,
1442  ROCDL::RawPtrBufferAtomicCmpSwap>,
1443  AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
1444  MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
1445  PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
1446  GatherToLDSOpLowering>(converter, chipset);
1447 }
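 // Minimal sketch of reusing these patterns outside the pass defined above
 // (hypothetical driver code; `op` and `ctx` stand for the root operation and
 // its MLIRContext, and mirror what ConvertAMDGPUToROCDLPass does):
 //   RewritePatternSet patterns(ctx);
 //   LLVMTypeConverter converter(ctx);
 //   populateAMDGPUToROCDLConversionPatterns(converter, patterns,
 //                                           Chipset(9, 4, 2));
 //   LLVMConversionTarget target(*ctx);
 //   target.addIllegalDialect<amdgpu::AMDGPUDialect>();
 //   target.addLegalDialect<LLVM::LLVMDialect, ROCDL::ROCDLDialect>();
 //   if (failed(applyPartialConversion(op, target, std::move(patterns))))
 //     /* report the conversion failure */;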