//===- AMDGPUToROCDL.cpp - AMDGPU to ROCDL dialect conversion -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"

#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"

#include "../LLVMCommon/MemRefDescriptor.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_CONVERTAMDGPUTOROCDLPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::amdgpu;

// Define commonly used chipset versions for convenience.
constexpr Chipset kGfx908 = Chipset(9, 0, 8);
constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
constexpr Chipset kGfx942 = Chipset(9, 4, 2);
constexpr Chipset kGfx950 = Chipset(9, 5, 0);

/// Convert an unsigned number `val` to i32.
static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
                                  Location loc, Value val) {
  IntegerType i32 = rewriter.getI32Type();
  // The cast asserts that `val` has an integer type.
  auto valTy = cast<IntegerType>(val.getType());
  if (i32 == valTy)
    return val;
  return valTy.getWidth() > 32
             ? Value(rewriter.create<LLVM::TruncOp>(loc, i32, val))
             : Value(rewriter.create<LLVM::ZExtOp>(loc, i32, val));
}

static Value createI32Constant(ConversionPatternRewriter &rewriter,
                               Location loc, int32_t value) {
  Type i32 = rewriter.getI32Type();
  return rewriter.create<LLVM::ConstantOp>(loc, i32, value);
}

static Value createI1Constant(ConversionPatternRewriter &rewriter,
                              Location loc, bool value) {
  Type llvmI1 = rewriter.getI1Type();
  return rewriter.create<LLVM::ConstantOp>(loc, llvmI1, value);
}

/// Returns the linear index used to access an element in the memref.
static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
                               Location loc,
                               MemRefDescriptor &memRefDescriptor,
                               ValueRange indices, ArrayRef<int64_t> strides) {
  IntegerType i32 = rewriter.getI32Type();
  Value index;
  for (auto [i, increment, stride] : llvm::enumerate(indices, strides)) {
    if (stride != 1) { // Skip if stride is 1.
      Value strideValue =
          ShapedType::isDynamic(stride)
              ? convertUnsignedToI32(rewriter, loc,
                                     memRefDescriptor.stride(rewriter, loc, i))
              : rewriter.create<LLVM::ConstantOp>(loc, i32, stride);
      increment = rewriter.create<LLVM::MulOp>(loc, increment, strideValue);
    }
    index =
        index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
  }
  return index ? index : createI32Constant(rewriter, loc, 0);
}

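// As a rough sketch of the emitted IR: for indices [%i, %j] into a memref
// with strides [S0, 1], the helper above produces the i32 linear element
// index %i * S0 + %j, schematically
//   %s0  = llvm.mlir.constant(S0 : i32)   // or read from the descriptor
//   %t   = llvm.mul %i, %s0 : i32
//   %idx = llvm.add %t, %j : i32
// The stride-1 dimension needs no multiply, and a rank-0 memref yields the
// constant 0.
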
90 /// Compute the contents of the `num_records` field for a given memref
91 /// descriptor - that is, the number of bytes that's one element past the
92 /// greatest possible valid index into the memref.
94  MemRefType memrefType,
95  MemRefDescriptor &memrefDescriptor,
96  ArrayRef<int64_t> strides,
97  uint32_t elementByteWidth) {
98  if (memrefType.hasStaticShape() &&
99  !llvm::any_of(strides, ShapedType::isDynamic)) {
100  int64_t size = memrefType.getRank() == 0 ? 1 : 0;
101  ArrayRef<int64_t> shape = memrefType.getShape();
102  for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
103  size = std::max(shape[i] * strides[i], size);
104  size = size * elementByteWidth;
105  assert(size < std::numeric_limits<uint32_t>::max() &&
106  "the memref buffer is too large");
107  return createI32Constant(rewriter, loc, static_cast<int32_t>(size));
108  }
109  Value maxIndex;
110  for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
111  Value size = memrefDescriptor.size(rewriter, loc, i);
112  Value stride = memrefDescriptor.stride(rewriter, loc, i);
113  Value maxThisDim = rewriter.create<LLVM::MulOp>(loc, size, stride);
114  maxIndex = maxIndex
115  ? rewriter.create<LLVM::UMaxOp>(loc, maxIndex, maxThisDim)
116  : maxThisDim;
117  }
118  Value maxIndexI32 = convertUnsignedToI32(rewriter, loc, maxIndex);
119  Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);
120  return rewriter.create<LLVM::MulOp>(loc, maxIndexI32, byteWidthConst);
121 }
122 
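// Worked example: for a static memref<4x8xf32> with strides [8, 1],
// max(4 * 8, 8 * 1) = 32 elements * 4 bytes = 128, so num_records is 128 and
// every in-bounds access falls strictly below that byte offset.
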
static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc,
                            Value basePointer, Value numRecords,
                            bool boundsCheck, amdgpu::Chipset chipset,
                            Value cacheSwizzleStride = nullptr,
                            unsigned addressSpace = 8) {
  // The stride value is generally 0. However, on MI-300 and onward, you can
  // enable a cache swizzling mode by setting bit 14 of the stride field
  // and setting that stride to a cache stride.
  Type i16 = rewriter.getI16Type();
  Value stride;
  if (chipset.majorVersion == 9 && chipset >= kGfx942 && cacheSwizzleStride) {
    Value cacheStrideZext =
        rewriter.create<LLVM::ZExtOp>(loc, i16, cacheSwizzleStride);
    Value swizzleBit = rewriter.create<LLVM::ConstantOp>(
        loc, i16, rewriter.getI16IntegerAttr(1 << 14));
    stride = rewriter.create<LLVM::OrOp>(loc, cacheStrideZext, swizzleBit,
                                         /*isDisjoint=*/true);
  } else {
    stride = rewriter.create<LLVM::ConstantOp>(loc, i16,
                                               rewriter.getI16IntegerAttr(0));
  }
  // Get the number of elements.
  // Flag word:
  // bits 0-11: dst sel, ignored by these intrinsics
  // bits 12-14: num format (ignored, must be nonzero, 7=float)
  // bits 15-18: data format (ignored, must be nonzero, 4=32bit)
  // bit 19: In nested heap (0 here)
  // bit 20: Behavior on unmap (0 means "return 0 / ignore")
  // bits 21-22: Index stride for swizzles (N/A)
  // bit 23: Add thread ID (0)
  // bit 24: Reserved to 1 (RDNA) or 0 (CDNA)
  // bits 25-26: Reserved (0)
  // bit 27: Buffer is non-volatile (CDNA only)
  // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 =
  //  none, 3 = either swizzles or testing against offset field) RDNA only
  // bits 30-31: Type (must be 0)
  uint32_t flags = (7 << 12) | (4 << 15);
  if (chipset.majorVersion >= 10) {
    flags |= (1 << 24);
    uint32_t oob = boundsCheck ? 3 : 2;
    flags |= (oob << 28);
  }
  Value flagsConst = createI32Constant(rewriter, loc, flags);
  Type rsrcType =
      LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
  Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
      loc, rsrcType, basePointer, stride, numRecords, flagsConst);
  return resource;
}

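// For reference, the folded flag word is (7 << 12) | (4 << 15) = 0x27000 on
// CDNA-style targets; on gfx10+ with bounds checking enabled it becomes
// 0x27000 | (1 << 24) | (3 << 28) = 0x31027000.
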
namespace {
struct FatRawBufferCastLowering
    : public ConvertOpToLLVMPattern<FatRawBufferCastOp> {
  FatRawBufferCastLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<FatRawBufferCastOp>(converter),
        chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value memRef = adaptor.getSource();
    Value unconvertedMemref = op.getSource();
    MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());
    MemRefDescriptor descriptor(memRef);

    DataLayout dataLayout = DataLayout::closest(op);
    int64_t elementByteWidth =
        dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;

    int64_t unusedOffset = 0;
    SmallVector<int64_t, 5> strideVals;
    if (failed(memrefType.getStridesAndOffset(strideVals, unusedOffset)))
      return op.emitOpError("Can't lower non-stride-offset memrefs");

    Value numRecords = adaptor.getValidBytes();
    if (!numRecords)
      numRecords = getNumRecords(rewriter, loc, memrefType, descriptor,
                                 strideVals, elementByteWidth);

    Value basePointer =
        adaptor.getResetOffset()
            ? descriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
                                   memrefType)
            : descriptor.alignedPtr(rewriter, loc);

    Value offset = adaptor.getResetOffset()
                       ? rewriter.create<LLVM::ConstantOp>(
                             loc, getIndexType(), rewriter.getIndexAttr(0))
                       : descriptor.offset(rewriter, loc);

    bool hasSizes = memrefType.getRank() > 0;
    // No need to unpack() and pack() all the individual sizes and strides,
    // so we'll just extract the arrays.
    Value sizes = hasSizes ? rewriter.create<LLVM::ExtractValueOp>(
                                 loc, descriptor, kSizePosInMemRefDescriptor)
                           : Value{};
    Value strides = hasSizes
                        ? rewriter.create<LLVM::ExtractValueOp>(
                              loc, descriptor, kStridePosInMemRefDescriptor)
                        : Value{};

    Value fatPtr = makeBufferRsrc(
        rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(),
        chipset, adaptor.getCacheSwizzleStride(), /*addressSpace=*/7);

    Value result = MemRefDescriptor::poison(
        rewriter, loc,
        getTypeConverter()->convertType(op.getResult().getType()));
    result = rewriter.create<LLVM::InsertValueOp>(
        loc, result, fatPtr, kAllocatedPtrPosInMemRefDescriptor);
    result = rewriter.create<LLVM::InsertValueOp>(
        loc, result, fatPtr, kAlignedPtrPosInMemRefDescriptor);
    result = rewriter.create<LLVM::InsertValueOp>(
        loc, result, offset, kOffsetPosInMemRefDescriptor);
    if (hasSizes) {
      result = rewriter.create<LLVM::InsertValueOp>(
          loc, result, sizes, kSizePosInMemRefDescriptor);
      result = rewriter.create<LLVM::InsertValueOp>(
          loc, result, strides, kStridePosInMemRefDescriptor);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

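// Schematically (an illustrative sketch, types abbreviated), the pattern
// above rewrites
//   %fat = amdgpu.fat_raw_buffer_cast %buf : memref<8xi32>
//       to memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
// into a rocdl.make.buffer.rsrc on the base pointer, producing an
// address-space-7 "fat" pointer that is re-packed into a memref descriptor
// reusing the original offset, sizes, and strides.
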
/// Define lowering patterns for raw buffer ops
template <typename GpuOp, typename Intrinsic>
struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
  RawBufferOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}

  Chipset chipset;
  static constexpr uint32_t maxVectorOpWidth = 128;

  LogicalResult
  matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = gpuOp.getLoc();
    Value memref = adaptor.getMemref();
    Value unconvertedMemref = gpuOp.getMemref();
    MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());

    if (chipset.majorVersion < 9)
      return gpuOp.emitOpError("raw buffer ops require GCN or higher");

    Value storeData = adaptor.getODSOperands(0)[0];
    if (storeData == memref) // no write component to this op
      storeData = Value();
    Type wantedDataType;
    if (storeData)
      wantedDataType = storeData.getType();
    else
      wantedDataType = gpuOp.getODSResults(0)[0].getType();

    Value atomicCmpData = Value();
    // Operand index 1 of a load is the indices, trying to read them can crash.
    if (storeData) {
      Value maybeCmpData = adaptor.getODSOperands(1)[0];
      if (maybeCmpData != memref)
        atomicCmpData = maybeCmpData;
    }

    Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);

    Type i32 = rewriter.getI32Type();

    // Get the type size in bytes.
    DataLayout dataLayout = DataLayout::closest(gpuOp);
    int64_t elementByteWidth =
        dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;
    Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);

    // If we want to load a vector<NxT> with total size <= 32 bits, use a
    // scalar load and bitcast it. Similarly, if bitsize(T) < 32 and the total
    // load size is >= 32, use a vector load of (N * bitsize(T) / 32) x i32 and
    // bitcast. Also, the CAS intrinsic requires integer operands, so bitcast
    // any floats to integers.
    Type llvmBufferValType = llvmWantedDataType;
    if (atomicCmpData) {
      if (auto floatType = dyn_cast<FloatType>(wantedDataType))
        llvmBufferValType = this->getTypeConverter()->convertType(
            rewriter.getIntegerType(floatType.getWidth()));
    }
    if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
      uint32_t vecLen = dataVector.getNumElements();
      uint32_t elemBits =
          dataLayout.getTypeSizeInBits(dataVector.getElementType());
      uint32_t totalBits = elemBits * vecLen;
      bool usePackedFp16 =
          isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
      if (totalBits > maxVectorOpWidth)
        return gpuOp.emitOpError(
            "Total width of loads or stores must be no more than " +
            Twine(maxVectorOpWidth) + " bits, but we call for " +
            Twine(totalBits) +
            " bits. This should've been caught in validation");
      if (!usePackedFp16 && elemBits < 32) {
        if (totalBits > 32) {
          if (totalBits % 32 != 0)
            return gpuOp.emitOpError("Load or store of more than 32-bits that "
                                     "doesn't fit into words. Can't happen\n");
          llvmBufferValType = this->typeConverter->convertType(
              VectorType::get(totalBits / 32, i32));
        } else {
          llvmBufferValType = this->typeConverter->convertType(
              rewriter.getIntegerType(totalBits));
        }
      }
    }
    if (auto vecType = dyn_cast<VectorType>(llvmBufferValType)) {
      // Buffer intrinsics don't support 1-element vectors, cast them to
      // scalars.
      if (vecType.getNumElements() == 1)
        llvmBufferValType = vecType.getElementType();
    }

    SmallVector<Value, 6> args;
    if (storeData) {
      if (llvmBufferValType != llvmWantedDataType) {
        Value castForStore =
            rewriter.create<LLVM::BitcastOp>(loc, llvmBufferValType, storeData);
        args.push_back(castForStore);
      } else {
        args.push_back(storeData);
      }
    }

    if (atomicCmpData) {
      if (llvmBufferValType != llvmWantedDataType) {
        Value castForCmp = rewriter.create<LLVM::BitcastOp>(
            loc, llvmBufferValType, atomicCmpData);
        args.push_back(castForCmp);
      } else {
        args.push_back(atomicCmpData);
      }
    }

    // Construct buffer descriptor from memref, attributes
    int64_t offset = 0;
    SmallVector<int64_t, 5> strides;
    if (failed(memrefType.getStridesAndOffset(strides, offset)))
      return gpuOp.emitOpError("Can't lower non-stride-offset memrefs");

    MemRefDescriptor memrefDescriptor(memref);

    Value ptr = memrefDescriptor.bufferPtr(
        rewriter, loc, *this->getTypeConverter(), memrefType);
    Value numRecords = getNumRecords(
        rewriter, loc, memrefType, memrefDescriptor, strides, elementByteWidth);
    Value resource = makeBufferRsrc(rewriter, loc, ptr, numRecords,
                                    adaptor.getBoundsCheck(), chipset);
    args.push_back(resource);

    // Indexing (voffset)
    Value voffset = getLinearIndexI32(rewriter, loc, memrefDescriptor,
                                      adaptor.getIndices(), strides);
    if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
        indexOffset && *indexOffset > 0) {
      Value extraOffsetConst = createI32Constant(rewriter, loc, *indexOffset);
      voffset =
          voffset ? rewriter.create<LLVM::AddOp>(loc, voffset, extraOffsetConst)
                  : extraOffsetConst;
    }
    voffset = rewriter.create<LLVM::MulOp>(loc, voffset, byteWidthConst);
    args.push_back(voffset);

    // SGPR offset.
    Value sgprOffset = adaptor.getSgprOffset();
    if (!sgprOffset)
      sgprOffset = createI32Constant(rewriter, loc, 0);
    sgprOffset = rewriter.create<LLVM::MulOp>(loc, sgprOffset, byteWidthConst);
    args.push_back(sgprOffset);

    // bit 0: GLC = 0 (atomics drop value, less coherency)
    // bits 1-2: SLC, DLC = 0 (similarly)
    // bit 3: swizzled (0 for raw)
    args.push_back(createI32Constant(rewriter, loc, 0));

    llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
                                           llvmBufferValType);
    Operation *lowered = rewriter.create<Intrinsic>(loc, resultTypes, args,
                                                    ArrayRef<NamedAttribute>());
    if (lowered->getNumResults() == 1) {
      Value replacement = lowered->getResult(0);
      if (llvmBufferValType != llvmWantedDataType) {
        replacement = rewriter.create<LLVM::BitcastOp>(loc, llvmWantedDataType,
                                                       replacement);
      }
      rewriter.replaceOp(gpuOp, replacement);
    } else {
      rewriter.eraseOp(gpuOp);
    }
    return success();
  }
};

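// Illustrative lowering (a sketch; the exact intrinsic depends on the op):
// an amdgpu.raw_buffer_load %buf[%i] of f32 roughly becomes
//   %rsrc    = rocdl.make.buffer.rsrc %ptr, %stride, %numRecords, %flags
//   %voffset = llvm.mul %i, %c4 : i32          // byte offset
//   %loaded  = <buffer load intrinsic> %rsrc, %voffset, %sgprOffset, 0
// with bitcasts inserted whenever the buffer value type chosen above differs
// from the requested result type.
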
struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
  LDSBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    bool requiresInlineAsm = chipset < kGfx90a || chipset.majorVersion == 11;

    if (requiresInlineAsm) {
      auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
                                                      LLVM::AsmDialect::AD_ATT);
      const char *asmStr =
          ";;;WARNING: BREAKS DEBUG WATCHES\ns_waitcnt lgkmcnt(0)\ns_barrier";
      const char *constraints = "";
      rewriter.replaceOpWithNewOp<LLVM::InlineAsmOp>(
          op,
          /*resultTypes=*/TypeRange(), /*operands=*/ValueRange(),
          /*asm_string=*/asmStr, constraints, /*has_side_effects=*/true,
          /*is_align_stack=*/false, LLVM::TailCallKind::None,
          /*asm_dialect=*/asmDialectAttr,
          /*operand_attrs=*/ArrayAttr());
      return success();
    }
    if (chipset.majorVersion < 12) {
      constexpr int32_t ldsOnlyBitsGfx6789 = ~(0x1f << 8);
      constexpr int32_t ldsOnlyBitsGfx10 = ~(0x3f << 8);
      // Left in place in case someone disables the inline ASM path or future
      // chipsets use the same bit pattern.
      constexpr int32_t ldsOnlyBitsGfx11 = ~(0x3f << 4);

      int32_t ldsOnlyBits;
      if (chipset.majorVersion == 11)
        ldsOnlyBits = ldsOnlyBitsGfx11;
      else if (chipset.majorVersion == 10)
        ldsOnlyBits = ldsOnlyBitsGfx10;
      else if (chipset.majorVersion <= 9)
        ldsOnlyBits = ldsOnlyBitsGfx6789;
      else
        return op.emitOpError(
                   "don't know how to lower this for chipset major version")
               << chipset.majorVersion;

      Location loc = op->getLoc();
      rewriter.create<ROCDL::SWaitcntOp>(loc, ldsOnlyBits);
      rewriter.replaceOpWithNewOp<ROCDL::SBarrierOp>(op);
    } else {
      Location loc = op->getLoc();
      rewriter.create<ROCDL::WaitDscntOp>(loc, 0);
      rewriter.create<ROCDL::BarrierSignalOp>(loc, -1);
      rewriter.replaceOpWithNewOp<ROCDL::BarrierWaitOp>(op, -1);
    }

    return success();
  }
};

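// In summary, the barrier lowers to one of three sequences (sketched):
//   pre-gfx90a / gfx11: inline asm "s_waitcnt lgkmcnt(0); s_barrier"
//   gfx90a..gfx10:      ROCDL::SWaitcntOp (LDS-only mask) + ROCDL::SBarrierOp
//   gfx12+:             ROCDL::WaitDscntOp(0) + ROCDL::BarrierSignalOp(-1)
//                       + ROCDL::BarrierWaitOp(-1)
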
struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
  SchedBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<SchedBarrierOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<ROCDL::SchedBarrier>(op,
                                                     (uint32_t)op.getOpts());
    return success();
  }
};

} // namespace

/// Converts an MFMA vector operand from the MLIR AMDGPU dialect convention to
/// the ROCDL and LLVM AMDGPU intrinsics convention.
///
/// Specifically:
/// 1. If the element type is bfloat16, bitcast it to i16 unless the rocdl
/// intrinsic allows bf16. Newer MFMAs support bf16 operands; see
/// IntrinsicsAMDGPU.td for reference.
/// 2. If instead we have a more than 64-bit quantity, use a <N / 4 x i32>
/// instead, which is what the f8f6f4 intrinsics use.
/// 3. If `input` is a vector of N <= 8 bytes, bitcast it to a (N * 8)-bit
/// integer.
///
/// Note that the type of `input` has already been LLVM type converted:
/// therefore 8-bit and smaller floats are represented as their corresponding
/// `iN` integers.
static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
                                      Location loc, Value input,
                                      bool allowBf16 = true) {
  Type inputType = input.getType();
  if (auto vectorType = dyn_cast<VectorType>(inputType)) {
    if (vectorType.getElementType().isBF16() && !allowBf16)
      return rewriter.create<LLVM::BitcastOp>(
          loc, vectorType.clone(rewriter.getI16Type()), input);
    if (vectorType.getElementType().isInteger(8) &&
        vectorType.getNumElements() <= 8)
      return rewriter.create<LLVM::BitcastOp>(
          loc, rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
    if (isa<IntegerType>(vectorType.getElementType()) &&
        vectorType.getElementTypeBitWidth() <= 8) {
      int64_t numWords = llvm::divideCeil(
          vectorType.getNumElements() * vectorType.getElementTypeBitWidth(),
          32);
      return rewriter.create<LLVM::BitcastOp>(
          loc, VectorType::get(numWords, rewriter.getI32Type()), input);
    }
  }
  return input;
}

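// For example (a sketch): with allowBf16 == false, vector<4xbf16> is bitcast
// to vector<4xi16>; vector<8xi8> becomes i64; and vector<16xi8> (128 bits,
// e.g. LLVM-converted fp8 data) becomes vector<4xi32>, as the f8f6f4
// intrinsics expect.
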
/// Converts the scaled MFMA operands, `scalesA` and `scalesB`, from the MLIR
/// AMDGPU dialect convention to the ROCDL and LLVM AMDGPU intrinsics
/// convention.
///
/// Specifically:
/// 1. If `input` is an i8 value, zero extend it to i32.
/// 2. If `input` is a vector of length 4 and type i8, bitcast it to i32.
///
/// Note that the type of `input` has already been LLVM type converted:
/// therefore 8-bit and smaller floats are represented as their corresponding
/// `iN` integers.
static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter,
                                  Location loc, Value input) {
  Type inputType = input.getType();
  Type outputType = rewriter.getI32Type();
  if (auto intType = dyn_cast<IntegerType>(inputType))
    return rewriter.create<LLVM::ZExtOp>(loc, outputType, input);
  return rewriter.create<LLVM::BitcastOp>(loc, outputType, input);
}

/// Push an input operand. If it is a float type, nothing to do. If it is
/// an integer type, then we need to also push its signedness (1 for signed, 0
/// for unsigned) and we need to pack the input 16xi8 vector into a 4xi32
/// vector (or the 8xi8 vector into a 2xi32 one for gfx12+).
/// We also need to convert bfloat inputs to i16 to account for the bfloat
/// intrinsics having been defined before the AMD backend supported bfloat. We
/// similarly need to pack 8-bit float types into integers as if they were i8
/// (which they are for the backend's purposes).
static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
                                 Location loc,
                                 const TypeConverter *typeConverter,
                                 bool isUnsigned, Value llvmInput,
                                 Value mlirInput,
                                 SmallVector<Value, 4> &operands) {
  Type inputType = llvmInput.getType();
  auto vectorType = dyn_cast<VectorType>(inputType);
  if (!vectorType) {
    operands.push_back(llvmInput);
    return;
  }
  Type elemType = vectorType.getElementType();

  if (elemType.isBF16())
    llvmInput = rewriter.create<LLVM::BitcastOp>(
        loc, vectorType.clone(rewriter.getI16Type()), llvmInput);
  if (elemType.getIntOrFloatBitWidth() > 8) {
    operands.push_back(llvmInput);
    return;
  }

  // We need to check the type of the input before conversion to properly test
  // for int8. This is because, in LLVM, fp8 type is converted to int8, so the
  // fp8/int8 information is lost during the conversion process.
  auto mlirInputType = cast<VectorType>(mlirInput.getType());
  bool isInputInteger = mlirInputType.getElementType().isInteger();
  if (isInputInteger) {
    // If the element type is explicitly signed or unsigned, ignore the
    // isUnsigned flag.
    bool localIsUnsigned = isUnsigned;
    if (elemType.isUnsignedInteger()) {
      localIsUnsigned = true;
    } else if (elemType.isSignedInteger()) {
      localIsUnsigned = false;
    }
    Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
    operands.push_back(sign);
  }

  int64_t numBits =
      vectorType.getNumElements() * elemType.getIntOrFloatBitWidth();
  Type i32 = rewriter.getI32Type();
  Type intrinsicInType = numBits <= 32
                             ? (Type)rewriter.getIntegerType(numBits)
                             : (Type)VectorType::get(numBits / 32, i32);
  auto llvmIntrinsicInType = typeConverter->convertType(intrinsicInType);
  Value castInput = rewriter.createOrFold<LLVM::BitcastOp>(
      loc, llvmIntrinsicInType, llvmInput);
  // The wave64-mode 16x16x16 intrinsics that take 4-bit integers only need
  // (256 / 64) * 4 = 16 bits of input (on gfx12+) but take i32 arguments.
  // Add in the zeros here.
  if (numBits < 32)
    castInput = rewriter.create<LLVM::ZExtOp>(loc, i32, castInput);
  operands.push_back(castInput);
}

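// For example (a sketch): a vector<16xi8> input with isUnsigned == false
// pushes an i1 sign bit of 1 followed by the data repacked as vector<4xi32>,
// while an f16 or i16 vector (element width > 8) is pushed through unchanged.
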
/// Push the output operand. For many cases this is only pushing the output in
/// the operand list. But when we have f16 -> f16 or bf16 -> bf16 intrinsics,
/// since the same number of VGPRs is used, we need to decide whether to store
/// the result in the upper or lower 16 bits of the VGPRs. To store the result
/// in the lower 16 bits, set subwordOffset to 1; otherwise the result will be
/// stored in the upper part. The subwordOffset must not be set for gfx12,
/// as the instructions have been changed to return fewer registers instead.
static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
                                  Location loc,
                                  const TypeConverter *typeConverter,
                                  Value output, int32_t subwordOffset,
                                  bool clamp, SmallVector<Value, 4> &operands) {
  Type inputType = output.getType();
  auto vectorType = dyn_cast<VectorType>(inputType);
  Type elemType = vectorType.getElementType();
  if (elemType.isBF16())
    output = rewriter.create<LLVM::BitcastOp>(
        loc, vectorType.clone(rewriter.getI16Type()), output);
  operands.push_back(output);
  if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) {
    operands.push_back(createI1Constant(rewriter, loc, subwordOffset));
  } else if (elemType.isInteger(32)) {
    operands.push_back(createI1Constant(rewriter, loc, clamp));
  }
}

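// For example (a sketch): an f16 accumulator is pushed followed by an i1
// OPSEL flag derived from subwordOffset; an i32 accumulator is pushed
// followed by the i1 clamp flag; bf16 accumulators are first bitcast to i16.
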
/// Return true if `type` is the E5M2 variant of an 8-bit float that is
/// supported by the `_bf8` instructions on the given `chipset`.
static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type) {
  return (chipset == kGfx942 && isa<Float8E5M2FNUZType>(type)) ||
         (hasOcpFp8(chipset) && isa<Float8E5M2Type>(type));
}

/// Return true if `type` is the E4M3FN variant of an 8-bit float that is
/// supported by the `_fp8` instructions on the given `chipset`.
static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type) {
  return (chipset == kGfx942 && isa<Float8E4M3FNUZType>(type)) ||
         (hasOcpFp8(chipset) && isa<Float8E4M3FNType>(type));
}

/// Return the `rocdl` intrinsic corresponding to an MFMA operation `mfma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
                                                  Chipset chipset) {
  uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
           b = mfma.getBlocks();
  Type sourceElem = getElementTypeOrSelf(mfma.getSourceA().getType());
  Type destElem = getElementTypeOrSelf(mfma.getDestC().getType());

  if (sourceElem.isF32() && destElem.isF32()) {
    if (mfma.getReducePrecision() && chipset >= kGfx942) {
      if (m == 32 && n == 32 && k == 4 && b == 1)
        return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
      if (m == 16 && n == 16 && k == 8 && b == 1)
        return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
    }
    if (m == 32 && n == 32 && k == 1 && b == 2)
      return ROCDL::mfma_f32_32x32x1f32::getOperationName();
    if (m == 16 && n == 16 && k == 1 && b == 4)
      return ROCDL::mfma_f32_16x16x1f32::getOperationName();
    if (m == 4 && n == 4 && k == 1 && b == 16)
      return ROCDL::mfma_f32_4x4x1f32::getOperationName();
    if (m == 32 && n == 32 && k == 2 && b == 1)
      return ROCDL::mfma_f32_32x32x2f32::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f32_16x16x4f32::getOperationName();
  }

  if (sourceElem.isF16() && destElem.isF32()) {
    if (chipset >= kGfx950) {
      if (m == 32 && n == 32 && k == 16 && b == 1)
        return ROCDL::mfma_f32_32x32x16_f16::getOperationName();
      if (m == 16 && n == 16 && k == 32 && b == 1)
        return ROCDL::mfma_f32_16x16x32_f16::getOperationName();
    }
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_f32_32x32x4f16::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_f32_16x16x4f16::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_f32_4x4x4f16::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_f32_32x32x8f16::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_f32_16x16x16f16::getOperationName();
  }

  if (sourceElem.isBF16() && destElem.isF32()) {
    if (chipset >= kGfx950) {
      if (m == 32 && n == 32 && k == 16 && b == 1)
        return ROCDL::mfma_f32_32x32x16_bf16::getOperationName();
      if (m == 16 && n == 16 && k == 32 && b == 1)
        return ROCDL::mfma_f32_16x16x32_bf16::getOperationName();
    }
    if (chipset >= kGfx90a) {
      if (m == 32 && n == 32 && k == 4 && b == 2)
        return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
      if (m == 16 && n == 16 && k == 4 && b == 4)
        return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
      if (m == 4 && n == 4 && k == 4 && b == 16)
        return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
      if (m == 32 && n == 32 && k == 8 && b == 1)
        return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
      if (m == 16 && n == 16 && k == 16 && b == 1)
        return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
    }
    if (m == 32 && n == 32 && k == 2 && b == 2)
      return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
    if (m == 16 && n == 16 && k == 2 && b == 4)
      return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
    if (m == 4 && n == 4 && k == 2 && b == 16)
      return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
    if (m == 32 && n == 32 && k == 4 && b == 1)
      return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
    if (m == 16 && n == 16 && k == 8 && b == 1)
      return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
  }

  if (sourceElem.isInteger(8) && destElem.isInteger(32)) {
    if (chipset >= kGfx950) {
      if (m == 32 && n == 32 && k == 32 && b == 1)
        return ROCDL::mfma_i32_32x32x32_i8::getOperationName();
      if (m == 16 && n == 16 && k == 64 && b == 1)
        return ROCDL::mfma_i32_16x16x64_i8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_i32_32x32x4i8::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_i32_16x16x4i8::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_i32_4x4x4i8::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_i32_32x32x8i8::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_i32_16x16x16i8::getOperationName();
    if (m == 32 && n == 32 && k == 16 && b == 1 && chipset >= kGfx942)
      return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
    if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >= kGfx942)
      return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
  }

  if (sourceElem.isF64() && destElem.isF64() && chipset >= kGfx90a) {
    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f64_16x16x4f64::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 4)
      return ROCDL::mfma_f64_4x4x4f64::getOperationName();
  }

  if (destElem.isF32() && typeIsExpectedBf8ForChipset(chipset, sourceElem)) {
    // Known to be correct because there are no scalar f8 instructions and
    // because a length mismatch will have been caught by the verifier.
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
    }
  }

  if (destElem.isF32() && typeIsExpectedFp8ForChipset(chipset, sourceElem)) {
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
    }
  }

  return std::nullopt;
}

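// For example (a sketch): an amdgpu.mfma with m = n = 32, k = 8, blocks = 1,
// f16 sources, and an f32 accumulator maps to
// ROCDL::mfma_f32_32x32x8f16::getOperationName(); the same shape with bf16
// sources requires at least gfx90a (the _1k variant).
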
static std::optional<uint32_t> mfmaTypeSelectCode(Type mlirElemType) {
  return llvm::TypeSwitch<Type, std::optional<uint32_t>>(mlirElemType)
      .Case([](Float8E4M3FNType) { return 0u; })
      .Case([](Float8E5M2Type) { return 1u; })
      .Case([](Float6E2M3FNType) { return 2u; })
      .Case([](Float6E3M2FNType) { return 3u; })
      .Case([](Float4E2M1FNType) { return 4u; })
      .Default([](Type) { return std::nullopt; });
}

/// If there is a scaled MFMA instruction for the input element types `aType`
/// and `bType`, output type `destType`, problem size M, N, K, and B (number of
/// blocks) on the given `chipset`, return a tuple consisting of the
/// OperationName of the intrinsic and the type codes that need to be passed to
/// that intrinsic. Note that this is also used to implement some un-scaled
/// MFMAs, since the compiler represents the ordinary instruction as a "scaled"
/// MFMA with a scale of 0.
static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m,
                        uint32_t n, uint32_t k, uint32_t b, Chipset chipset) {
  aType = getElementTypeOrSelf(aType);
  bType = getElementTypeOrSelf(bType);
  destType = getElementTypeOrSelf(destType);

  if (chipset < kGfx950)
    return std::nullopt;
  if (!isa<Float32Type>(destType))
    return std::nullopt;

  std::optional<uint32_t> aTypeCode = mfmaTypeSelectCode(aType);
  std::optional<uint32_t> bTypeCode = mfmaTypeSelectCode(bType);
  if (!aTypeCode || !bTypeCode)
    return std::nullopt;

  if (m == 32 && n == 32 && k == 64 && b == 1)
    return std::tuple{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(),
                      *aTypeCode, *bTypeCode};
  if (m == 16 && n == 16 && k == 128 && b == 1)
    return std::tuple{
        ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), *aTypeCode,
        *bTypeCode};

  return std::nullopt;
}

static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
mfmaOpToScaledIntrinsic(MFMAOp mfma, Chipset chipset) {
  return mfmaOpToScaledIntrinsic(
      mfma.getSourceA().getType(), mfma.getSourceB().getType(),
      mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(),
      mfma.getBlocks(), chipset);
}

static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
  return mfmaOpToScaledIntrinsic(smfma.getSourceA().getType(),
                                 smfma.getSourceB().getType(),
                                 smfma.getDestC().getType(), smfma.getM(),
                                 smfma.getN(), smfma.getK(), 1u, chipset);
}

/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
                                                  Chipset chipset) {
  auto sourceVectorType = dyn_cast<VectorType>(wmma.getSourceA().getType());
  auto sourceBVectorType = dyn_cast<VectorType>(wmma.getSourceB().getType());
  auto destVectorType = dyn_cast<VectorType>(wmma.getDestC().getType());
  auto elemSourceType = sourceVectorType.getElementType();
  auto elemBSourceType = sourceBVectorType.getElementType();
  auto elemDestType = destVectorType.getElementType();

  if (elemSourceType.isF16() && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
  if (elemSourceType.isBF16() && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
  if (elemSourceType.isF16() && elemDestType.isF16())
    return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
  if (elemSourceType.isBF16() && elemDestType.isBF16())
    return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
  if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
    return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
  if (chipset.majorVersion == 11) {
    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
      return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
  }
  if (chipset.majorVersion >= 12) {
    if (isa<Float8E4M3FNType>(elemSourceType) &&
        isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
    if (isa<Float8E4M3FNType>(elemSourceType) &&
        isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
    if (isa<Float8E5M2Type>(elemSourceType) &&
        isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
    if (isa<Float8E5M2Type>(elemSourceType) &&
        isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) {
      bool isWave64 = destVectorType.getNumElements() == 4;
      // This is the ambiguous case. 8 inputs to the wave64 version means that
      // we want the 16x16x32 version, but for wave32 they mean the short form.
      bool has8Inputs = sourceVectorType.getNumElements() == 8;
      if ((isWave64 && has8Inputs) || (!isWave64 && !has8Inputs))
        return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
      return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
    }
  }
  return std::nullopt;
}

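// For example (a sketch): on gfx12, an int4 WMMA whose destination has 4
// elements (wave64) and whose A operand has 8 elements selects the 16x16x32
// intrinsic, while the same 8-element input with a wave32-sized destination
// selects the 16x16x16 form.
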
namespace {
struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
  MFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type outType = typeConverter->convertType(op.getDestD().getType());
    Type intrinsicOutType = outType;
    if (auto outVecType = dyn_cast<VectorType>(outType))
      if (outVecType.getElementType().isBF16())
        intrinsicOutType = outVecType.clone(rewriter.getI16Type());

    if (chipset.majorVersion != 9 || chipset < kGfx908)
      return op->emitOpError("MFMA only supported on gfx908+");
    uint32_t getBlgpField = static_cast<uint32_t>(op.getBlgp());
    if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
      if (chipset < kGfx942)
        return op.emitOpError("negation unsupported on older than gfx942");
      getBlgpField |=
          op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
    }
    std::optional<StringRef> maybeIntrinsic = mfmaOpToIntrinsic(op, chipset);
    std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
        maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
    if (!maybeIntrinsic.has_value() && !maybeScaledIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching MFMA size on given chipset");

    bool isScaled =
        !maybeIntrinsic.has_value() && maybeScaledIntrinsic.has_value();
    if (isScaled &&
        (adaptor.getAbid() > 0 || getBlgpField > 0 || op.getCbsz() > 0)) {
      return op.emitOpError(
          "non-default abid, blgp, and cbsz aren't supported on MFMAs that can "
          "be scaled as those fields are used for type information");
    }

    StringRef intrinsicName =
        isScaled ? std::get<0>(*maybeScaledIntrinsic) : *maybeIntrinsic;
    // Determine if we can use bf16 in the intrinsic. Newer MFMAs on gfx950+
    // allow bf16 as the input. For reference, check IntrinsicsAMDGPU.td.
    bool allowBf16 = [&]() {
      if (chipset < kGfx950)
        return false;
      if (isScaled)
        return true;
      return intrinsicName.contains("16x16x32.bf16") ||
             intrinsicName.contains("32x32x16.bf16");
    }();
    OperationState loweredOp(loc, intrinsicName);
    loweredOp.addTypes(intrinsicOutType);
    loweredOp.addOperands({convertMFMAVectorOperand(
                               rewriter, loc, adaptor.getSourceA(), allowBf16),
                           convertMFMAVectorOperand(
                               rewriter, loc, adaptor.getSourceB(), allowBf16),
                           adaptor.getDestC()});
    if (isScaled) {
      Value zero = createI32Constant(rewriter, loc, 0);
      auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
      loweredOp.addOperands({createI32Constant(rewriter, loc, aTypeCode),
                             createI32Constant(rewriter, loc, bTypeCode),
                             /*scale A byte=*/zero, /*scale A=*/zero,
                             /*scale B byte=*/zero, /*scale B=*/zero});
    } else {
      loweredOp.addOperands({createI32Constant(rewriter, loc, op.getCbsz()),
                             createI32Constant(rewriter, loc, op.getAbid()),
                             createI32Constant(rewriter, loc, getBlgpField)});
    }
    Value lowered = rewriter.create(loweredOp)->getResult(0);
    if (outType != intrinsicOutType)
      lowered = rewriter.create<LLVM::BitcastOp>(loc, outType, lowered);
    rewriter.replaceOp(op, lowered);
    return success();
  }
};

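// Note on the scaled path above (a sketch): when only the scaled intrinsic
// exists for a shape (e.g. f8 inputs on gfx950), the ordinary amdgpu.mfma is
// emitted as that scaled intrinsic with both scales and both scale-byte
// selectors set to zero, per the scale-of-0 convention described on
// mfmaOpToScaledIntrinsic.
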
struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
  ScaledMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type intrinsicOutType = typeConverter->convertType(op.getDestD().getType());

    if (chipset.majorVersion != 9 || chipset < kGfx950)
      return op->emitOpError("scaled MFMA only supported on gfx950+");
    std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
        maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
    if (!maybeScaledIntrinsic.has_value())
      return op.emitOpError(
          "no intrinsic matching scaled MFMA size on given chipset");

    auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
    OperationState loweredOp(loc, intrinsicName);
    loweredOp.addTypes(intrinsicOutType);
    loweredOp.addOperands(
        {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
         convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
         adaptor.getDestC()});
    Value scalesIdxA =
        createI32Constant(rewriter, loc, adaptor.getScalesIdxA());
    Value scalesIdxB =
        createI32Constant(rewriter, loc, adaptor.getScalesIdxB());
    loweredOp.addOperands(
        {createI32Constant(rewriter, loc, aTypeCode),
         createI32Constant(rewriter, loc, bTypeCode),
         /*scales idx A=*/scalesIdxA,
         /*scales A*/
         castMFMAScaleOperand(rewriter, loc, adaptor.getScalesA()),
         /*scales idx B=*/scalesIdxB,
         /*scales B*/
         castMFMAScaleOperand(rewriter, loc, adaptor.getScalesB())});
    Value lowered = rewriter.create(loweredOp)->getResult(0);
    rewriter.replaceOp(op, lowered);
    return success();
  }
};

struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
  WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto outType =
        typeConverter->convertType<VectorType>(op.getDestD().getType());
    if (!outType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
      return op->emitOpError("WMMA only supported on gfx11 and gfx12");

    // The WMMA operations represent vectors of bf16s as vectors of i16s, so we
    // need to bitcast bfloats to i16 and then bitcast them back.
    VectorType rawOutType = outType;
    if (outType.getElementType().isBF16())
      rawOutType = outType.clone(rewriter.getI16Type());

    std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);

    if (!maybeIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching WMMA on the given chipset");

    if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0)
      return op.emitOpError("subwordOffset not supported on gfx12+");

    OperationState loweredOp(loc, *maybeIntrinsic);
    loweredOp.addTypes(rawOutType);

    SmallVector<Value, 4> operands;
    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(),
                         adaptor.getSourceA(), op.getSourceA(), operands);
    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(),
                         adaptor.getSourceB(), op.getSourceB(), operands);
    wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(),
                          op.getSubwordOffset(), op.getClamp(), operands);

    loweredOp.addOperands(operands);
    Operation *lowered = rewriter.create(loweredOp);

    Operation *maybeCastBack = lowered;
    if (rawOutType != outType)
      maybeCastBack =
          rewriter.create<LLVM::BitcastOp>(loc, outType, lowered->getResult(0));
    rewriter.replaceOp(op, maybeCastBack->getResults());

    return success();
  }
};

struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
  GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset.majorVersion < 9 || chipset.majorVersion > 10)
      return op.emitOpError("pre-gfx9 and post-gfx10 not supported");

    Location loc = op.getLoc();

    auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
    auto dstMemRefType = cast<MemRefType>(op.getDst().getType());

    // TODO: instead of only transferring one element per thread, we could
    // augment it to transfer multiple elements per thread by issuing multiple
    // `global_load_lds` instructions.
    Type transferType = op.getTransferType();
    size_t loadWidth = [&]() -> size_t {
      if (auto transferVectorType = dyn_cast<VectorType>(transferType)) {
        return transferVectorType.getNumElements() *
               (transferVectorType.getElementTypeBitWidth() / 8);
      }
      return transferType.getIntOrFloatBitWidth() / 8;
    }();

    // Currently only 1, 2, and 4 byte loads are supported.
    if (loadWidth != 1 && loadWidth != 2 && loadWidth != 4)
      return op.emitOpError("unsupported element size for this chipset");

    Value srcPtr =
        getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
                             (adaptor.getSrcIndices()));
    Value dstPtr =
        getStridedElementPtr(rewriter, loc, dstMemRefType, adaptor.getDst(),
                             (adaptor.getDstIndices()));

    rewriter.replaceOpWithNewOp<ROCDL::LoadToLDSOp>(
        op, srcPtr, dstPtr, rewriter.getI32IntegerAttr(loadWidth),
        /*offset=*/rewriter.getI32IntegerAttr(0),
        /*aux=*/rewriter.getI32IntegerAttr(0), ArrayAttr{}, ArrayAttr{},
        ArrayAttr{});

    return success();
  }
};

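// For example (a sketch): gathering one f32 per lane lowers to a
// ROCDL::LoadToLDSOp (global_load_lds) with a 4-byte transfer width, whose
// source and LDS destination pointers are computed from the strided memref
// descriptors.
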
namespace {
struct ExtPackedFp8OpLowering final
    : public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
  ExtPackedFp8OpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::ExtPackedFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

struct PackedTrunc2xFp8OpLowering final
    : public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> {
  PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter,
                             Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::PackedTrunc2xFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

struct PackedStochRoundFp8OpLowering final
    : public ConvertOpToLLVMPattern<PackedStochRoundFp8Op> {
  PackedStochRoundFp8OpLowering(const LLVMTypeConverter &converter,
                                Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::PackedStochRoundFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PackedStochRoundFp8Op op,
                  PackedStochRoundFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

struct ScaledExtPackedOpLowering final
    : public ConvertOpToLLVMPattern<ScaledExtPackedOp> {
  ScaledExtPackedOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedOp>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

struct PackedScaledTruncOpLowering final
    : public ConvertOpToLLVMPattern<PackedScaledTruncOp> {
  PackedScaledTruncOpLowering(const LLVMTypeConverter &converter,
                              Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::PackedScaledTruncOp>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

} // end namespace

LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
    ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type v4i8 =
      getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type()));
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
  Type f32 = getTypeConverter()->convertType(op.getResult().getType());

  Value source = adaptor.getSource();
  auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
  auto resultVecType = dyn_cast<VectorType>(op.getResult().getType());
  Type sourceElemType = getElementTypeOrSelf(op.getSource());
  // Extend to a v4i8
  if (!sourceVecType || sourceVecType.getNumElements() < 4) {
    Value longVec = rewriter.create<LLVM::UndefOp>(loc, v4i8);
    if (!sourceVecType) {
      longVec = rewriter.create<LLVM::InsertElementOp>(
          loc, longVec, source, createI32Constant(rewriter, loc, 0));
    } else {
      for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
        Value idx = createI32Constant(rewriter, loc, i);
        Value elem = rewriter.create<LLVM::ExtractElementOp>(loc, source, idx);
        longVec =
            rewriter.create<LLVM::InsertElementOp>(loc, longVec, elem, idx);
      }
    }
    source = longVec;
  }
  Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
  if (resultVecType) {
    if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
      rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source,
                                                        op.getIndex());
    } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
      rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Fp8Op>(op, f32, i32Source,
                                                        op.getIndex());
    }
  } else {
    if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
      rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
                                                      op.getIndex());
    } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
      rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
                                                      op.getIndex());
    }
  }
  return success();
}

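// For example (a sketch): extracting element 1 of a vector<4xf8E4M3FN> source
// on a chipset with OCP fp8 support bitcasts the packed bytes to an i32 and
// emits ROCDL::CvtF32Fp8Op on that word with index 1; vector results use the
// packed ROCDL::CvtPkF32Fp8Op form instead.
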
1274 LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
1275  ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
1276  ConversionPatternRewriter &rewriter) const {
1277  Location loc = op.getLoc();
1278  if (chipset != kGfx950)
1279  return rewriter.notifyMatchFailure(
1280  loc, "Scaled fp conversion instructions are not available on target "
1281  "architecture and their emulation is not implemented");
1282  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
1283 
1284  Value source = adaptor.getSource();
1285  Value scale = adaptor.getScale();
1286 
1287  VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
1288  Type sourceElemType = sourceVecType.getElementType();
1289  VectorType destVecType = cast<VectorType>(op.getResult().getType());
1290  Type destElemType = destVecType.getElementType();
1291 
1292  VectorType packedVecType;
1293  if (isa<Float8E5M2Type, Float8E4M3FNType>(sourceElemType)) {
1294  VectorType v4i8 = VectorType::get(4, rewriter.getI8Type());
1295  packedVecType = cast<VectorType>(getTypeConverter()->convertType(v4i8));
1296  } else if (isa<Float4E2M1FNType>(sourceElemType)) {
1297  VectorType v8i4 = VectorType::get(8, rewriter.getI4Type());
1298  packedVecType = cast<VectorType>(getTypeConverter()->convertType(v8i4));
1299  } else {
1300  llvm_unreachable("invalid element type for scaled ext");
1301  }
1302 
1303  // Extend to a packedVectorType
1304  if (sourceVecType.getNumElements() < packedVecType.getNumElements()) {
1305  Value longVec = rewriter.create<LLVM::ZeroOp>(loc, packedVecType);
1306  if (!sourceVecType) {
1307  longVec = rewriter.create<LLVM::InsertElementOp>(
1308  loc, longVec, source, createI32Constant(rewriter, loc, 0));
1309  } else {
1310  for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
1311  Value idx = createI32Constant(rewriter, loc, i);
1312  Value elem = rewriter.create<LLVM::ExtractElementOp>(loc, source, idx);
1313  longVec =
1314  rewriter.create<LLVM::InsertElementOp>(loc, longVec, elem, idx);
1315  }
1316  }
1317  source = longVec;
1318  }
1319  Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
1320 
1321  if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF32())
1322  rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Bf8Op>(
1323  op, destVecType, i32Source, scale, op.getIndex());
1324  else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF16())
1325  rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Bf8Op>(
1326  op, destVecType, i32Source, scale, op.getIndex());
1327  else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isBF16())
1328  rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Bf8Op>(
1329  op, destVecType, i32Source, scale, op.getIndex());
1330  else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF32())
1331  rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp8Op>(
1332  op, destVecType, i32Source, scale, op.getIndex());
1333  else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF16())
1334  rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp8Op>(
1335  op, destVecType, i32Source, scale, op.getIndex());
1336  else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isBF16())
1337  rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp8Op>(
1338  op, destVecType, i32Source, scale, op.getIndex());
1339  else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF32())
1340  rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp4Op>(
1341  op, destVecType, i32Source, scale, op.getIndex());
1342  else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF16())
1343  rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp4Op>(
1344  op, destVecType, i32Source, scale, op.getIndex());
1345  else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isBF16())
1346  rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp4Op>(
1347  op, destVecType, i32Source, scale, op.getIndex());
1348  else
1349  return failure();
1350 
1351  return success();
1352 }
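// Illustrative example (assumed IR spelling): on gfx950,
//   %ext = amdgpu.scaled_ext_packed %src[0], %scale
//            : vector<2xf8E5M2> to vector<2xf32>
// zero-extends %src to the four-element packed width, bitcasts it to i32,
// and lowers to rocdl.cvt.scalef32.pk.f32.bf8 with word index 0.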
1353 
1354 LogicalResult PackedScaledTruncOpLowering::matchAndRewrite(
1355  PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
1356  ConversionPatternRewriter &rewriter) const {
1357  Location loc = op.getLoc();
1358  if (chipset != kGfx950)
1359  return rewriter.notifyMatchFailure(
1360  loc, "Scaled fp conversion instructions are not available on target "
1361  "architecture and their emulation is not implemented");
1362  Type v2i16 = getTypeConverter()->convertType(
1363  VectorType::get(2, rewriter.getI16Type()));
1364  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
1365 
1366  Type resultType = op.getResult().getType();
1367  Type resultElemType = getElementTypeOrSelf(resultType);
1368  VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
1369  Type sourceElemType = sourceVecType.getElementType();
1370 
1371  Type intResultType = isa<Float4E2M1FNType>(resultElemType) ? i32 : v2i16;
1372 
1373  Value source = adaptor.getSource();
1374  Value scale = adaptor.getScale();
1375  Value existing = adaptor.getExisting();
1376  if (existing)
1377  existing = rewriter.create<LLVM::BitcastOp>(loc, intResultType, existing);
1378  else
1379  existing = rewriter.create<LLVM::ZeroOp>(loc, intResultType);
1380 
1381  if (sourceVecType.getNumElements() < 2) {
1382  Value c0 = createI32Constant(rewriter, loc, 0);
1383  Value elem0 = rewriter.create<LLVM::ExtractElementOp>(loc, source, c0);
1384  VectorType v2 = VectorType::get(2, sourceElemType);
1385  source = rewriter.create<LLVM::ZeroOp>(loc, v2);
1386  source = rewriter.create<LLVM::InsertElementOp>(loc, source, elem0, c0);
1387  }
1388 
1389  Value sourceA, sourceB;
1390  if (sourceElemType.isF32()) {
1391  Value c0 = createI32Constant(rewriter, loc, 0);
1392  Value c1 = createI32Constant(rewriter, loc, 1);
1393  sourceA = rewriter.create<LLVM::ExtractElementOp>(loc, source, c0);
1394  sourceB = rewriter.create<LLVM::ExtractElementOp>(loc, source, c1);
1395  }
1396 
1397  Value result;
1398  if (sourceElemType.isF32() && isa<Float8E5M2Type>(resultElemType))
1399  result = rewriter.create<ROCDL::CvtScaleF32PkBf8F32Op>(
1400  loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex());
1401  else if (sourceElemType.isF16() && isa<Float8E5M2Type>(resultElemType))
1402  result = rewriter.create<ROCDL::CvtScaleF32PkBf8F16Op>(
1403  loc, intResultType, existing, source, scale, op.getIndex());
1404  else if (sourceElemType.isBF16() && isa<Float8E5M2Type>(resultElemType))
1405  result = rewriter.create<ROCDL::CvtScaleF32PkBf8Bf16Op>(
1406  loc, intResultType, existing, source, scale, op.getIndex());
1407  else if (sourceElemType.isF32() && isa<Float8E4M3FNType>(resultElemType))
1408  result = rewriter.create<ROCDL::CvtScaleF32PkFp8F32Op>(
1409  loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex());
1410  else if (sourceElemType.isF16() && isa<Float8E4M3FNType>(resultElemType))
1411  result = rewriter.create<ROCDL::CvtScaleF32PkFp8F16Op>(
1412  loc, intResultType, existing, source, scale, op.getIndex());
1413  else if (sourceElemType.isBF16() && isa<Float8E4M3FNType>(resultElemType))
1414  result = rewriter.create<ROCDL::CvtScaleF32PkFp8Bf16Op>(
1415  loc, intResultType, existing, source, scale, op.getIndex());
1416  else if (sourceElemType.isF32() && isa<Float4E2M1FNType>(resultElemType))
1417  result = rewriter.create<ROCDL::CvtScaleF32PkFp4F32Op>(
1418  loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex());
1419  else if (sourceElemType.isF16() && isa<Float4E2M1FNType>(resultElemType))
1420  result = rewriter.create<ROCDL::CvtScaleF32PkFp4F16Op>(
1421  loc, intResultType, existing, source, scale, op.getIndex());
1422  else if (sourceElemType.isBF16() && isa<Float4E2M1FNType>(resultElemType))
1423  result = rewriter.create<ROCDL::CvtScaleF32PkFp4Bf16Op>(
1424  loc, intResultType, existing, source, scale, op.getIndex());
1425  else
1426  return failure();
1427 
1428  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
1429  op, getTypeConverter()->convertType(resultType), result);
1430  return success();
1431 }
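// Illustrative example (assumed IR spelling): on gfx950, truncating two f32
// values to packed fp4,
//   %t = amdgpu.packed_scaled_trunc %src into %existing[0], %scale
//          : vector<2xf32> to vector<8xf4E2M1FN>
// extracts the two f32 lanes, emits rocdl.cvt.scalef32.pk.fp4.f32 into an
// i32 container, and bitcasts the result back to the fp4 vector type.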
1432 
1433 LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
1434  PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
1435  ConversionPatternRewriter &rewriter) const {
1436  Location loc = op.getLoc();
1437  if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
1438  return rewriter.notifyMatchFailure(
1439  loc, "Fp8 conversion instructions are not available on target "
1440  "architecture and their emulation is not implemented");
1441  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
1442 
1443  Type resultType = op.getResult().getType();
1444  Type resultElemType = getElementTypeOrSelf(resultType);
1445 
1446  Value sourceA = adaptor.getSourceA();
1447  Value sourceB = adaptor.getSourceB();
1448  if (!sourceB)
1449  sourceB = rewriter.create<LLVM::UndefOp>(loc, sourceA.getType());
1450  Value existing = adaptor.getExisting();
1451  if (existing)
1452  existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
1453  else
1454  existing = rewriter.create<LLVM::UndefOp>(loc, i32);
1455 
1456  Value result;
1457  if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
1458  result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
1459  existing, op.getWordIndex());
1460  else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
1461  result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
1462  existing, op.getWordIndex());
1463 
1464  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
1465  op, getTypeConverter()->convertType(resultType), result);
1466  return success();
1467 }
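// Illustrative example (assumed IR spelling): on gfx942,
//   %w = amdgpu.packed_trunc_2xfp8 %a, %b into undef[word 0]
//          : f32 to vector<4xf8E4M3FNUZ>
// lowers to rocdl.cvt.pk.fp8.f32, which writes the two converted bytes into
// the half of the i32 selected by the word index.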
1468 
1469 LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
1470  PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
1471  ConversionPatternRewriter &rewriter) const {
1472  Location loc = op.getLoc();
1473  if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
1474  return rewriter.notifyMatchFailure(
1475  loc, "Fp8 conversion instructions are not available on target "
1476  "architecture and their emulation is not implemented");
1477  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
1478 
1479  Type resultType = op.getResult().getType();
1480  Type resultElemType = getElementTypeOrSelf(resultType);
1481 
1482  Value source = adaptor.getSource();
1483  Value stoch = adaptor.getStochiasticParam();
1484  Value existing = adaptor.getExisting();
1485  if (existing)
1486  existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
1487  else
1488  existing = rewriter.create<LLVM::UndefOp>(loc, i32);
1489 
1490  Value result;
1491  if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
1492  result = rewriter.create<ROCDL::CvtSrBf8F32Op>(
1493  loc, i32, source, stoch, existing, op.getStoreIndex());
1494  else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
1495  result = rewriter.create<ROCDL::CvtSrFp8F32Op>(
1496  loc, i32, source, stoch, existing, op.getStoreIndex());
1497 
1498  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
1499  op, getTypeConverter()->convertType(resultType), result);
1500  return success();
1501 }
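// Illustrative example (assumed IR spelling): on gfx942,
//   %w = amdgpu.packed_stoch_round_fp8 %src + %entropy into undef[1]
//          : f32 to vector<4xf8E4M3FNUZ>
// lowers to rocdl.cvt.sr.fp8.f32; %entropy supplies the random bits that
// drive the stochastic rounding, and the converted byte is stored at
// index 1.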
1502 
1503 // AMDGPUDPPLowering converts the amdgpu.dpp operation into the
1504 // corresponding ROCDL DPP update instruction.
1505 struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
1506  AMDGPUDPPLowering(const LLVMTypeConverter &converter, Chipset chipset)
1507  : ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
1508  Chipset chipset;
1509 
1510  LogicalResult
1511  matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
1512  ConversionPatternRewriter &rewriter) const override {
1513 
1514  // Fetch the operands and determine the corresponding LLVM types.
1515  Location loc = DppOp.getLoc();
1516  Value src = adaptor.getSrc();
1517  Value old = adaptor.getOld();
1518  Type srcType = src.getType();
1519  Type oldType = old.getType();
1520  Type llvmType = nullptr;
1521  if (srcType.getIntOrFloatBitWidth() < 32) {
1522  llvmType = rewriter.getI32Type();
1523  } else if (isa<FloatType>(srcType)) {
1524  llvmType = (srcType.getIntOrFloatBitWidth() == 32)
1525  ? rewriter.getF32Type()
1526  : rewriter.getF64Type();
1527  } else if (isa<IntegerType>(srcType)) {
1528  llvmType = (srcType.getIntOrFloatBitWidth() == 32)
1529  ? rewriter.getI32Type()
1530  : rewriter.getI64Type();
1531  }
1532  auto llvmSrcIntType = typeConverter->convertType(
1533  rewriter.getIntegerType(srcType.getIntOrFloatBitWidth()));
1534 
1535  // Operands of 16 bits or fewer are widened to an i32: bitcast floats to integers, insert the value into a vector that fills 32 bits, then bitcast to i32.
1536  auto convertOperand = [&](Value operand, Type operandType) {
1537  if (operandType.getIntOrFloatBitWidth() <= 16) {
1538  if (llvm::isa<FloatType>(operandType)) {
1539  operand =
1540  rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
1541  }
1542  auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
1543  32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
1544  Value undefVec = rewriter.create<LLVM::UndefOp>(loc, llvmVecType);
1545  operand = rewriter.create<LLVM::InsertElementOp>(
1546  loc, undefVec, operand, createI32Constant(rewriter, loc, 0));
1547  operand = rewriter.create<LLVM::BitcastOp>(loc, llvmType, operand);
1548  }
1549  return operand;
1550  };
1551 
1552  src = convertOperand(src, srcType);
1553  old = convertOperand(old, oldType);
1554 
1555  // These DPP control values are taken from llvm/lib/Target/AMDGPU/SIDefines.h.
1556  enum DppCtrl : unsigned {
1557  ROW_SHL0 = 0x100,
1558  ROW_SHR0 = 0x110,
1559  ROW_ROR0 = 0x120,
1560  WAVE_SHL1 = 0x130,
1561  WAVE_ROL1 = 0x134,
1562  WAVE_SHR1 = 0x138,
1563  WAVE_ROR1 = 0x13C,
1564  ROW_MIRROR = 0x140,
1565  ROW_HALF_MIRROR = 0x141,
1566  BCAST15 = 0x142,
1567  BCAST31 = 0x143,
1568  };
1569 
1570  auto kind = DppOp.getKind();
1571  auto permArgument = DppOp.getPermArgument();
1572  uint32_t DppCtrl = 0;
1573 
1574  switch (kind) {
1575 
1576  case DPPPerm::quad_perm:
1577  if (auto quadPermAttr = cast<ArrayAttr>(*permArgument)) {
1578  int32_t i = 0;
1579  for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
1580  uint32_t num = elem.getInt();
1581  DppCtrl |= num << (i * 2);
1582  i++;
1583  }
1584  }
1585  break;
1586  case DPPPerm::row_shl:
1587  if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
1588  DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
1589  }
1590  break;
1591  case DPPPerm::row_shr:
1592  if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
1593  DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
1594  }
1595  break;
1596  case DPPPerm::row_ror:
1597  if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
1598  DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
1599  }
1600  break;
1601  case DPPPerm::wave_shl:
1602  DppCtrl = DppCtrl::WAVE_SHL1;
1603  break;
1604  case DPPPerm::wave_shr:
1605  DppCtrl = DppCtrl::WAVE_SHR1;
1606  break;
1607  case DPPPerm::wave_rol:
1608  DppCtrl = DppCtrl::WAVE_ROL1;
1609  break;
1610  case DPPPerm::wave_ror:
1611  DppCtrl = DppCtrl::WAVE_ROR1;
1612  break;
1613  case DPPPerm::row_mirror:
1614  DppCtrl = DppCtrl::ROW_MIRROR;
1615  break;
1616  case DPPPerm::row_half_mirror:
1617  DppCtrl = DppCtrl::ROW_HALF_MIRROR;
1618  break;
1619  case DPPPerm::row_bcast_15:
1620  DppCtrl = DppCtrl::BCAST15;
1621  break;
1622  case DPPPerm::row_bcast_31:
1623  DppCtrl = DppCtrl::BCAST31;
1624  break;
1625  }
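  // Worked example (derived from the packing loop above): quad_perm
  // [1, 0, 3, 2] yields DppCtrl = 1 << 0 | 0 << 2 | 3 << 4 | 2 << 6 = 0xB1,
  // which swaps lanes 0<->1 and 2<->3 within every quad of a row.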
1626 
1627  // Read the row_mask, bank_mask, and bound_ctrl attributes so they can
1628  // be passed to the DPP update op below.
1629  auto rowMask = DppOp->getAttrOfType<IntegerAttr>("row_mask").getInt();
1630  auto bankMask = DppOp->getAttrOfType<IntegerAttr>("bank_mask").getInt();
1631  bool boundCtrl = DppOp->getAttrOfType<BoolAttr>("bound_ctrl").getValue();
1632 
1633  // Create a ROCDL DPPUpdateOp with the computed control value and masks.
1634  auto dppMovOp = rewriter.create<ROCDL::DPPUpdateOp>(
1635  loc, llvmType, old, src, DppCtrl, rowMask, bankMask, boundCtrl);
1636 
1637  Value result = dppMovOp.getRes();
1638  if (srcType.getIntOrFloatBitWidth() < 32) {
1639  result = rewriter.create<LLVM::TruncOp>(loc, llvmSrcIntType, result);
1640  if (!llvm::isa<IntegerType>(srcType)) {
1641  result = rewriter.create<LLVM::BitcastOp>(loc, srcType, result);
1642  }
1643  }
1644 
1645  // Replace the amdgpu.dpp op with the (possibly truncated and bitcast)
1646  // result of the ROCDL DPP update.
1647  rewriter.replaceOp(DppOp, ValueRange(result));
1648  return success();
1649  }
1650 };
1651 
1652 struct AMDGPUSwizzleBitModeLowering
1653  : public ConvertOpToLLVMPattern<SwizzleBitModeOp> {
1654  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
1655 
1656  LogicalResult
1657  matchAndRewrite(SwizzleBitModeOp op, OpAdaptor adaptor,
1658  ConversionPatternRewriter &rewriter) const override {
1659  Location loc = op.getLoc();
1660  Type i32 = rewriter.getI32Type();
1661  Value src = adaptor.getSrc();
1662  SmallVector<Value> decomposed =
1663  LLVM::decomposeValue(rewriter, loc, src, i32);
1664  unsigned andMask = op.getAndMask();
1665  unsigned orMask = op.getOrMask();
1666  unsigned xorMask = op.getXorMask();
1667 
1668  // Bit 15 is 0 for the BitMode swizzle.
1669  // https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/
1670  unsigned mask = andMask | (orMask << 5) | (xorMask << 10);
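  // Worked example (numbers follow the packing above): and_mask = 0x1F,
  // or_mask = 0x00, xor_mask = 0x01 gives mask = 0x1F | (0x00 << 5) |
  // (0x01 << 10) = 0x041F, which XORs each lane id with 1 and therefore
  // swaps adjacent lanes.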
1671  Value maskValue = createI32Constant(rewriter, loc, mask);
1672  SmallVector<Value> swizzled;
1673  for (Value v : decomposed) {
1674  Value res =
1675  rewriter.create<ROCDL::DsSwizzleOp>(loc, v.getType(), v, maskValue);
1676  swizzled.emplace_back(res);
1677  }
1678 
1679  Value result = LLVM::composeValue(rewriter, loc, swizzled, src.getType());
1680  rewriter.replaceOp(op, result);
1681  return success();
1682  }
1683 };
1684 
1685 struct ConvertAMDGPUToROCDLPass
1686  : public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
1687  using Base::Base;
1688 
1689  void runOnOperation() override {
1690  MLIRContext *ctx = &getContext();
1691  FailureOr<Chipset> maybeChipset = Chipset::parse(chipset);
1692  if (failed(maybeChipset)) {
1693  emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
1694  return signalPassFailure();
1695  }
1696 
1698  LLVMTypeConverter converter(ctx);
1699  populateAMDGPUToROCDLConversionPatterns(converter, patterns, *maybeChipset);
1700  LLVMConversionTarget target(getContext());
1701  target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
1702  target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
1703  target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
1704  if (failed(applyPartialConversion(getOperation(), target,
1705  std::move(patterns))))
1706  signalPassFailure();
1707  }
1708 };
1709 } // namespace
1710 
1711 void mlir::populateAMDGPUMemorySpaceAttributeConversions(
1712  TypeConverter &typeConverter) {
1713  typeConverter.addTypeAttributeConversion(
1714  [](BaseMemRefType type, amdgpu::AddressSpaceAttr as)
1715  -> TypeConverter::AttributeConversionResult {
1716  MLIRContext *ctx = as.getContext();
1717  Type i64 = IntegerType::get(ctx, 64);
1718  switch (as.getValue()) {
1719  case amdgpu::AddressSpace::FatRawBuffer:
1720  return IntegerAttr::get(i64, 7);
1721  case amdgpu::AddressSpace::BufferRsrc:
1722  return IntegerAttr::get(i64, 8);
1723  case amdgpu::AddressSpace::FatStructuredBuffer:
1724  return IntegerAttr::get(i64, 9);
1725  }
1726  return TypeConverter::AttributeConversionResult::abort();
1727  });
1728 }
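// For example (attribute spelling assumed from the AMDGPU dialect
// documentation), a memref typed
//   memref<16xf32, #amdgpu.address_space<fat_raw_buffer>>
// is rewritten to use LLVM address space 7 per the mapping above.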
1729 
1730 void mlir::populateAMDGPUToROCDLConversionPatterns(
1731  LLVMTypeConverter &converter, RewritePatternSet &patterns,
1732  Chipset chipset) {
1733  populateAMDGPUMemorySpaceAttributeConversions(converter);
1734  patterns
1735  .add<FatRawBufferCastLowering,
1736  RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
1737  RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
1738  RawBufferOpLowering<RawBufferAtomicFaddOp,
1739  ROCDL::RawPtrBufferAtomicFaddOp>,
1740  RawBufferOpLowering<RawBufferAtomicFmaxOp,
1741  ROCDL::RawPtrBufferAtomicFmaxOp>,
1742  RawBufferOpLowering<RawBufferAtomicSmaxOp,
1743  ROCDL::RawPtrBufferAtomicSmaxOp>,
1744  RawBufferOpLowering<RawBufferAtomicUminOp,
1745  ROCDL::RawPtrBufferAtomicUminOp>,
1746  RawBufferOpLowering<RawBufferAtomicCmpswapOp,
1747  ROCDL::RawPtrBufferAtomicCmpSwap>,
1748  AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
1749  MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering,
1750  ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
1751  PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
1752  PackedStochRoundFp8OpLowering, GatherToLDSOpLowering>(converter,
1753  chipset);
1754  patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
1755 }
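// Usage sketch (hypothetical driver code; the populate function and Chipset
// are the public entry points of this file, the surrounding boilerplate is
// assumed):
//   RewritePatternSet patterns(ctx);
//   LLVMTypeConverter converter(ctx);
//   populateAMDGPUToROCDLConversionPatterns(converter, patterns,
//                                           Chipset(9, 4, 2)); // gfx942
//   (void)applyPartialConversion(module, target, std::move(patterns));
// or, from the command line: mlir-opt --convert-amdgpu-to-rocdl=chipset=gfx942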