//===- AMDGPUToROCDL.cpp - AMDGPU to ROCDL dialect conversion -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"

#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"

#include "../LLVMCommon/MemRefDescriptor.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_CONVERTAMDGPUTOROCDLPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::amdgpu;

// Define commonly used chipset versions for convenience.
constexpr Chipset kGfx908 = Chipset(9, 0, 8);
constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
constexpr Chipset kGfx942 = Chipset(9, 4, 2);
constexpr Chipset kGfx950 = Chipset(9, 5, 0);

/// Convert an unsigned number `val` to i32.
static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
                                  Location loc, Value val) {
  IntegerType i32 = rewriter.getI32Type();
  // Force-check (via the cast) that `val` is an integer type.
  auto valTy = cast<IntegerType>(val.getType());
  if (i32 == valTy)
    return val;
  return valTy.getWidth() > 32
             ? Value(rewriter.create<LLVM::TruncOp>(loc, i32, val))
             : Value(rewriter.create<LLVM::ZExtOp>(loc, i32, val));
}

static Value createI32Constant(ConversionPatternRewriter &rewriter,
                               Location loc, int32_t value) {
  Type i32 = rewriter.getI32Type();
  return rewriter.create<LLVM::ConstantOp>(loc, i32, value);
}

static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
                              bool value) {
  Type llvmI1 = rewriter.getI1Type();
  return rewriter.create<LLVM::ConstantOp>(loc, llvmI1, value);
}

/// Returns the linear index used to access an element in the memref.
static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
                               Location loc, MemRefDescriptor &memRefDescriptor,
                               ValueRange indices, ArrayRef<int64_t> strides) {
  IntegerType i32 = rewriter.getI32Type();
  Value index;
  for (auto [i, increment, stride] : llvm::enumerate(indices, strides)) {
    if (stride != 1) { // Skip the multiplication when the stride is 1.
      Value strideValue =
          ShapedType::isDynamic(stride)
              ? convertUnsignedToI32(rewriter, loc,
                                     memRefDescriptor.stride(rewriter, loc, i))
              : rewriter.create<LLVM::ConstantOp>(loc, i32, stride);
      increment = rewriter.create<LLVM::MulOp>(loc, increment, strideValue);
    }
    index =
        index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
  }
  return index ? index : createI32Constant(rewriter, loc, 0);
}
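
// Illustrative example (not from the upstream file): given static strides
// [4, 1] and indices (%i, %j), getLinearIndexI32 above emits roughly
//   %c4 = llvm.mlir.constant(4 : i32) : i32
//   %0  = llvm.mul %i, %c4 : i32
//   %1  = llvm.add %0, %j : i32
// i.e. the usual row-major linearization i * 4 + j; dynamic strides are
// instead loaded from the memref descriptor and truncated/extended to i32.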

/// Compute the contents of the `num_records` field for a given memref
/// descriptor - that is, the byte offset one element past the greatest
/// possible valid index into the memref.
static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc,
                           MemRefType memrefType,
                           MemRefDescriptor &memrefDescriptor,
                           ArrayRef<int64_t> strides,
                           uint32_t elementByteWidth) {
  if (memrefType.hasStaticShape() &&
      !llvm::any_of(strides, ShapedType::isDynamic)) {
    int64_t size = memrefType.getRank() == 0 ? 1 : 0;
    ArrayRef<int64_t> shape = memrefType.getShape();
    for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
      size = std::max(shape[i] * strides[i], size);
    size = size * elementByteWidth;
    assert(size < std::numeric_limits<uint32_t>::max() &&
           "the memref buffer is too large");
    return createI32Constant(rewriter, loc, static_cast<int32_t>(size));
  }
  Value maxIndex;
  for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
    Value size = memrefDescriptor.size(rewriter, loc, i);
    Value stride = memrefDescriptor.stride(rewriter, loc, i);
    Value maxThisDim = rewriter.create<LLVM::MulOp>(loc, size, stride);
    maxIndex = maxIndex
                   ? rewriter.create<LLVM::UMaxOp>(loc, maxIndex, maxThisDim)
                   : maxThisDim;
  }
  Value maxIndexI32 = convertUnsignedToI32(rewriter, loc, maxIndex);
  Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);
  return rewriter.create<LLVM::MulOp>(loc, maxIndexI32, byteWidthConst);
}
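
// Worked example (illustrative): for memref<16x4xf32> with static strides
// [4, 1], the static path above computes max(16 * 4, 4 * 1) = 64 elements,
// then 64 * 4 bytes, so `num_records` is the i32 constant 256, one element
// past the highest addressable byte offset.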

static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc,
                            Value basePointer, Value numRecords,
                            bool boundsCheck, amdgpu::Chipset chipset,
                            Value cacheSwizzleStride = nullptr,
                            unsigned addressSpace = 8) {
  // The stride value is generally 0. However, on MI-300 and onward, you can
  // enable a cache swizzling mode by setting bit 14 of the stride field
  // and setting that stride to a cache stride.
  Type i16 = rewriter.getI16Type();
  Value stride;
  if (chipset.majorVersion == 9 && chipset >= kGfx942 && cacheSwizzleStride) {
    Value cacheStrideZext =
        rewriter.create<LLVM::ZExtOp>(loc, i16, cacheSwizzleStride);
    Value swizzleBit = rewriter.create<LLVM::ConstantOp>(
        loc, i16, rewriter.getI16IntegerAttr(1 << 14));
    stride = rewriter.create<LLVM::OrOp>(loc, cacheStrideZext, swizzleBit,
                                         /*isDisjoint=*/true);
  } else {
    stride = rewriter.create<LLVM::ConstantOp>(loc, i16,
                                               rewriter.getI16IntegerAttr(0));
  }
  // Flag word:
  // bits 0-11: dst sel, ignored by these intrinsics
  // bits 12-14: num format (ignored, must be nonzero, 7=float)
  // bits 15-18: data format (ignored, must be nonzero, 4=32bit)
  // bit 19: In nested heap (0 here)
  // bit 20: Behavior on unmap (0 means "return 0 / ignore")
  // bits 21-22: Index stride for swizzles (N/A)
  // bit 23: Add thread ID (0)
  // bit 24: Reserved to 1 (RDNA) or 0 (CDNA)
  // bits 25-26: Reserved (0)
  // bit 27: Buffer is non-volatile (CDNA only)
  // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 =
  // none, 3 = either swizzles or testing against offset field) RDNA only
  // bits 30-31: Type (must be 0)
  uint32_t flags = (7 << 12) | (4 << 15);
  if (chipset.majorVersion >= 10) {
    flags |= (1 << 24);
    uint32_t oob = boundsCheck ? 3 : 2;
    flags |= (oob << 28);
  }
  Value flagsConst = createI32Constant(rewriter, loc, flags);
  Type rsrcType =
      LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
  Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
      loc, rsrcType, basePointer, stride, numRecords, flagsConst);
  return resource;
}
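
// For instance (a sketch following the flag layout above): on a gfx11 target
// with boundsCheck = true, flags becomes
// (7 << 12) | (4 << 15) | (1 << 24) | (3u << 28), and the resulting
// rocdl.make.buffer.rsrc yields a !llvm.ptr<8> resource by default;
// FatRawBufferCastLowering below passes addressSpace = 7 instead to obtain a
// directly dereferenceable "fat" buffer pointer.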

namespace {
struct FatRawBufferCastLowering
    : public ConvertOpToLLVMPattern<FatRawBufferCastOp> {
  FatRawBufferCastLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<FatRawBufferCastOp>(converter),
        chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value memRef = adaptor.getSource();
    Value unconvertedMemref = op.getSource();
    MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());
    MemRefDescriptor descriptor(memRef);

    DataLayout dataLayout = DataLayout::closest(op);
    int64_t elementByteWidth =
        dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;

    int64_t unusedOffset = 0;
    SmallVector<int64_t, 5> strideVals;
    if (failed(memrefType.getStridesAndOffset(strideVals, unusedOffset)))
      return op.emitOpError("can't lower non-stride-offset memrefs");

    Value numRecords = adaptor.getValidBytes();
    if (!numRecords)
      numRecords = getNumRecords(rewriter, loc, memrefType, descriptor,
                                 strideVals, elementByteWidth);

    Value basePointer =
        adaptor.getResetOffset()
            ? descriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
                                   memrefType)
            : descriptor.alignedPtr(rewriter, loc);

    Value offset = adaptor.getResetOffset()
                       ? rewriter.create<LLVM::ConstantOp>(
                             loc, getIndexType(), rewriter.getIndexAttr(0))
                       : descriptor.offset(rewriter, loc);

    bool hasSizes = memrefType.getRank() > 0;
    // No need to unpack() and pack() all the individual sizes and strides,
    // so we'll just extract the arrays.
    Value sizes = hasSizes ? rewriter.create<LLVM::ExtractValueOp>(
                                 loc, descriptor, kSizePosInMemRefDescriptor)
                           : Value{};
    Value strides = hasSizes
                        ? rewriter.create<LLVM::ExtractValueOp>(
                              loc, descriptor, kStridePosInMemRefDescriptor)
                        : Value{};

    Value fatPtr = makeBufferRsrc(
        rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(),
        chipset, adaptor.getCacheSwizzleStride(), /*addressSpace=*/7);

    Value result = MemRefDescriptor::poison(
        rewriter, loc,
        getTypeConverter()->convertType(op.getResult().getType()));
    result = rewriter.create<LLVM::InsertValueOp>(
        loc, result, fatPtr, kAllocatedPtrPosInMemRefDescriptor);
    result = rewriter.create<LLVM::InsertValueOp>(
        loc, result, fatPtr, kAlignedPtrPosInMemRefDescriptor);
    result = rewriter.create<LLVM::InsertValueOp>(loc, result, offset,
                                                  kOffsetPosInMemRefDescriptor);
    if (hasSizes) {
      result = rewriter.create<LLVM::InsertValueOp>(loc, result, sizes,
                                                    kSizePosInMemRefDescriptor);
      result = rewriter.create<LLVM::InsertValueOp>(
          loc, result, strides, kStridePosInMemRefDescriptor);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};
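
// End-to-end sketch of this pattern (hand-written, types abbreviated):
//   %fat = amdgpu.fat_raw_buffer_cast %m : memref<8xf32>
//            to memref<8xf32, #amdgpu.address_space<fat_raw_buffer>>
// becomes a rocdl.make.buffer.rsrc on the aligned base pointer with
// num_records = 32 bytes, re-packed into a memref descriptor whose allocated
// and aligned pointers are the resulting !llvm.ptr<7> fat pointer.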

/// Define lowering patterns for raw buffer ops.
template <typename GpuOp, typename Intrinsic>
struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
  RawBufferOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}

  Chipset chipset;
  static constexpr uint32_t maxVectorOpWidth = 128;

  LogicalResult
  matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = gpuOp.getLoc();
    Value memref = adaptor.getMemref();
    Value unconvertedMemref = gpuOp.getMemref();
    MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());

    if (chipset.majorVersion < 9)
      return gpuOp.emitOpError("raw buffer ops require GCN or higher");

    Value storeData = adaptor.getODSOperands(0)[0];
    if (storeData == memref) // No write component to this op.
      storeData = Value();
    Type wantedDataType;
    if (storeData)
      wantedDataType = storeData.getType();
    else
      wantedDataType = gpuOp.getODSResults(0)[0].getType();

    Value atomicCmpData = Value();
    // Operand index 1 of a load is the indices; trying to read them can crash.
    if (storeData) {
      Value maybeCmpData = adaptor.getODSOperands(1)[0];
      if (maybeCmpData != memref)
        atomicCmpData = maybeCmpData;
    }

    Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);

    Type i32 = rewriter.getI32Type();

    // Get the type size in bytes.
    DataLayout dataLayout = DataLayout::closest(gpuOp);
    int64_t elementByteWidth =
        dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;
    Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);

    // If we want to load a vector<NxT> with total size <= 32
    // bits, use a scalar load and bitcast it. Similarly, if bitsize(T) < 32
    // and the total load size is >= 32, use a vector load of N / (bitsize(T) /
    // 32) x i32 and bitcast. Also, the CAS intrinsic requires integer operands,
    // so bitcast any floats to integers.
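    // Concretely (illustrative instances of the rules above): a load of
    // vector<2xf16> (32 bits total) becomes a single i32 buffer load that is
    // bitcast back; vector<8xi8> (64 bits) becomes vector<2xi32>; and
    // vector<6xi8> (48 bits) is rejected because it does not fill whole words.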
    Type llvmBufferValType = llvmWantedDataType;
    if (atomicCmpData) {
      if (auto floatType = dyn_cast<FloatType>(wantedDataType))
        llvmBufferValType = this->getTypeConverter()->convertType(
            rewriter.getIntegerType(floatType.getWidth()));
    }
    if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
      uint32_t vecLen = dataVector.getNumElements();
      uint32_t elemBits =
          dataLayout.getTypeSizeInBits(dataVector.getElementType());
      uint32_t totalBits = elemBits * vecLen;
      bool usePackedFp16 =
          isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
      if (totalBits > maxVectorOpWidth)
        return gpuOp.emitOpError(
            "total width of loads or stores must be no more than " +
            Twine(maxVectorOpWidth) + " bits, but we call for " +
            Twine(totalBits) +
            " bits; this should've been caught in validation");
      if (!usePackedFp16 && elemBits < 32) {
        if (totalBits > 32) {
          if (totalBits % 32 != 0)
            return gpuOp.emitOpError("load or store of more than 32 bits that "
                                     "doesn't fit into words; can't happen");
          llvmBufferValType = this->typeConverter->convertType(
              VectorType::get(totalBits / 32, i32));
        } else {
          llvmBufferValType = this->typeConverter->convertType(
              rewriter.getIntegerType(totalBits));
        }
      }
    }
    if (auto vecType = dyn_cast<VectorType>(llvmBufferValType)) {
      // Buffer intrinsics don't support 1-element vectors; cast them to
      // scalars.
      if (vecType.getNumElements() == 1)
        llvmBufferValType = vecType.getElementType();
    }

    SmallVector<Value, 6> args;
    if (storeData) {
      if (llvmBufferValType != llvmWantedDataType) {
        Value castForStore =
            rewriter.create<LLVM::BitcastOp>(loc, llvmBufferValType, storeData);
        args.push_back(castForStore);
      } else {
        args.push_back(storeData);
      }
    }

    if (atomicCmpData) {
      if (llvmBufferValType != llvmWantedDataType) {
        Value castForCmp = rewriter.create<LLVM::BitcastOp>(
            loc, llvmBufferValType, atomicCmpData);
        args.push_back(castForCmp);
      } else {
        args.push_back(atomicCmpData);
      }
    }

    // Construct the buffer descriptor from the memref and the attributes.
    int64_t offset = 0;
    SmallVector<int64_t, 5> strides;
    if (failed(memrefType.getStridesAndOffset(strides, offset)))
      return gpuOp.emitOpError("can't lower non-stride-offset memrefs");

    MemRefDescriptor memrefDescriptor(memref);

    Value ptr = memrefDescriptor.bufferPtr(
        rewriter, loc, *this->getTypeConverter(), memrefType);
    Value numRecords = getNumRecords(
        rewriter, loc, memrefType, memrefDescriptor, strides, elementByteWidth);
    Value resource = makeBufferRsrc(rewriter, loc, ptr, numRecords,
                                    adaptor.getBoundsCheck(), chipset);
    args.push_back(resource);

    // Indexing (voffset).
    Value voffset = getLinearIndexI32(rewriter, loc, memrefDescriptor,
                                      adaptor.getIndices(), strides);
    if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
        indexOffset && *indexOffset > 0) {
      Value extraOffsetConst = createI32Constant(rewriter, loc, *indexOffset);
      voffset =
          voffset ? rewriter.create<LLVM::AddOp>(loc, voffset, extraOffsetConst)
                  : extraOffsetConst;
    }
    voffset = rewriter.create<LLVM::MulOp>(loc, voffset, byteWidthConst);
    args.push_back(voffset);

    // SGPR offset.
    Value sgprOffset = adaptor.getSgprOffset();
    if (!sgprOffset)
      sgprOffset = createI32Constant(rewriter, loc, 0);
    sgprOffset = rewriter.create<LLVM::MulOp>(loc, sgprOffset, byteWidthConst);
    args.push_back(sgprOffset);

    // bit 0: GLC = 0 (atomics drop value, less coherency)
    // bits 1-2: SLC, DLC = 0 (similarly)
    // bit 3: swizzled (0 for raw)
    args.push_back(createI32Constant(rewriter, loc, 0));

    llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
                                           llvmBufferValType);
    Operation *lowered = rewriter.create<Intrinsic>(loc, resultTypes, args,
                                                    ArrayRef<NamedAttribute>());
    if (lowered->getNumResults() == 1) {
      Value replacement = lowered->getResult(0);
      if (llvmBufferValType != llvmWantedDataType) {
        replacement = rewriter.create<LLVM::BitcastOp>(loc, llvmWantedDataType,
                                                       replacement);
      }
      rewriter.replaceOp(gpuOp, replacement);
    } else {
      rewriter.eraseOp(gpuOp);
    }
    return success();
  }
};

struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
  LDSBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    bool requiresInlineAsm = chipset < kGfx90a || chipset.majorVersion == 11;

    if (requiresInlineAsm) {
      auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
                                                      LLVM::AsmDialect::AD_ATT);
      const char *asmStr =
          ";;;WARNING: BREAKS DEBUG WATCHES\ns_waitcnt lgkmcnt(0)\ns_barrier";
      const char *constraints = "";
      rewriter.replaceOpWithNewOp<LLVM::InlineAsmOp>(
          op,
          /*resultTypes=*/TypeRange(), /*operands=*/ValueRange(),
          /*asm_string=*/asmStr, constraints, /*has_side_effects=*/true,
          /*is_align_stack=*/false, LLVM::TailCallKind::None,
          /*asm_dialect=*/asmDialectAttr,
          /*operand_attrs=*/ArrayAttr());
      return success();
    }
    if (chipset.majorVersion < 12) {
      constexpr int32_t ldsOnlyBitsGfx6789 = ~(0x1f << 8);
      constexpr int32_t ldsOnlyBitsGfx10 = ~(0x3f << 8);
      // Left in place in case someone disables the inline ASM path or future
      // chipsets use the same bit pattern.
      constexpr int32_t ldsOnlyBitsGfx11 = ~(0x3f << 4);

      int32_t ldsOnlyBits;
      if (chipset.majorVersion == 11)
        ldsOnlyBits = ldsOnlyBitsGfx11;
      else if (chipset.majorVersion == 10)
        ldsOnlyBits = ldsOnlyBitsGfx10;
      else if (chipset.majorVersion <= 9)
        ldsOnlyBits = ldsOnlyBitsGfx6789;
      else
        return op.emitOpError(
                   "don't know how to lower this for chipset major version")
               << chipset.majorVersion;

      Location loc = op->getLoc();
      rewriter.create<ROCDL::SWaitcntOp>(loc, ldsOnlyBits);
      rewriter.replaceOpWithNewOp<ROCDL::SBarrierOp>(op);
    } else {
      Location loc = op->getLoc();
      rewriter.create<ROCDL::WaitDscntOp>(loc, 0);
      rewriter.create<ROCDL::BarrierSignalOp>(loc, -1);
      rewriter.replaceOpWithNewOp<ROCDL::BarrierWaitOp>(op, -1);
    }

    return success();
  }
};

struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
  SchedBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<SchedBarrierOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<ROCDL::SchedBarrier>(op,
                                                     (uint32_t)op.getOpts());
    return success();
  }
};

} // namespace

/// Converts a MFMA vector operand from MLIR AMDGPU dialect convention to ROCDL
/// and LLVM AMDGPU intrinsics convention.
///
/// Specifically:
/// 1. If the element type is bfloat16, bitcast it to i16.
/// 2. If instead we have a more than 64-bit quantity, use a <N / 4 x i32>
/// instead, which is what the f8f6f4 intrinsics use.
/// 3. If `input` is a vector of N <= 8 bytes, bitcast it to a (N * 8)-bit
/// integer.
///
/// Note that the type of `input` has already been LLVM type converted:
/// therefore 8-bit and smaller floats are represented as their corresponding
/// `iN` integers.
static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
                                      Location loc, Value input) {
  Type inputType = input.getType();
  if (auto vectorType = dyn_cast<VectorType>(inputType)) {
    if (vectorType.getElementType().isBF16())
      return rewriter.create<LLVM::BitcastOp>(
          loc, vectorType.clone(rewriter.getI16Type()), input);
    if (vectorType.getElementType().isInteger(8) &&
        vectorType.getNumElements() <= 8)
      return rewriter.create<LLVM::BitcastOp>(
          loc, rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
    if (isa<IntegerType>(vectorType.getElementType()) &&
        vectorType.getElementTypeBitWidth() <= 8) {
      int64_t numWords = llvm::divideCeil(
          vectorType.getNumElements() * vectorType.getElementTypeBitWidth(),
          32);
      return rewriter.create<LLVM::BitcastOp>(
          loc, VectorType::get(numWords, rewriter.getI32Type()), input);
    }
  }
  return input;
}
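
// Examples of the conversions above (illustrative): vector<4xbf16> becomes
// vector<4xi16>; vector<8xi8> (which includes already-type-converted fp8
// operands) becomes i64; and vector<32xi8>, as used by the f8f6f4
// intrinsics, becomes vector<8xi32>.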

/// Converts the scaled MFMA operands, `scalesA` and `scalesB`, from MLIR AMDGPU
/// dialect convention to ROCDL and LLVM AMDGPU intrinsics convention.
///
/// Specifically:
/// 1. If `input` is an i8 value, zero-extend it to i32.
/// 2. If `input` is a vector of length 4 and type i8, bitcast it to i32.
///
/// Note that the type of `input` has already been LLVM type converted:
/// therefore 8-bit and smaller floats are represented as their corresponding
/// `iN` integers.
static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter,
                                  Location loc, Value input) {
  Type inputType = input.getType();
  Type outputType = rewriter.getI32Type();
  if (auto intType = dyn_cast<IntegerType>(inputType))
    return rewriter.create<LLVM::ZExtOp>(loc, outputType, input);
  return rewriter.create<LLVM::BitcastOp>(loc, outputType, input);
}
552 
553 /// Push an input operand. If it is a float type, nothing to do. If it is
554 /// an integer type, then we need to also push its signdness (1 for signed, 0
555 /// for unsigned) and we need to pack the input 16xi8 vector into a 4xi32
556 /// vector (or the 8xi8 vector into a 2xi32 one for gfx12+).
557 /// We also need to convert bfloat inputs to i16 to account for the bfloat
558 /// intrinsics having been defined before the AMD backend supported bfloat. We
559 /// similarly need to pack 8-bit float types into integers as if they were i8
560 /// (which they are for the backend's purposes).
562  Location loc,
563  const TypeConverter *typeConverter,
564  bool isUnsigned, Value llvmInput,
565  Value mlirInput,
566  SmallVector<Value, 4> &operands) {
567  Type inputType = llvmInput.getType();
568  auto vectorType = dyn_cast<VectorType>(inputType);
569  if (!vectorType) {
570  operands.push_back(llvmInput);
571  return;
572  }
573  Type elemType = vectorType.getElementType();
574 
575  if (elemType.isBF16())
576  llvmInput = rewriter.create<LLVM::BitcastOp>(
577  loc, vectorType.clone(rewriter.getI16Type()), llvmInput);
578  if (elemType.getIntOrFloatBitWidth() > 8) {
579  operands.push_back(llvmInput);
580  return;
581  }
582 
583  // We need to check the type of the input before conversion to properly test
584  // for int8. This is because, in LLVM, fp8 type is converted to int8, so the
585  // fp8/int8 information is lost during the conversion process.
586  auto mlirInputType = cast<VectorType>(mlirInput.getType());
587  bool isInputInteger = mlirInputType.getElementType().isInteger();
588  if (isInputInteger) {
589  // if element type is 8-bit signed or unsigned, ignore the isUnsigned flag
590  bool localIsUnsigned = isUnsigned;
591  if (elemType.isUnsignedInteger()) {
592  localIsUnsigned = true;
593  } else if (elemType.isSignedInteger()) {
594  localIsUnsigned = false;
595  }
596  Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
597  operands.push_back(sign);
598  }
599 
600  int64_t numBits =
601  vectorType.getNumElements() * elemType.getIntOrFloatBitWidth();
602  Type i32 = rewriter.getI32Type();
603  Type intrinsicInType = numBits <= 32
604  ? (Type)rewriter.getIntegerType(numBits)
605  : (Type)VectorType::get(numBits / 32, i32);
606  auto llvmIntrinsicInType = typeConverter->convertType(intrinsicInType);
607  Value castInput = rewriter.createOrFold<LLVM::BitcastOp>(
608  loc, llvmIntrinsicInType, llvmInput);
609  // The wave64-mode 16x16x16 intrinsics that take 4-bit integers only need
610  // (256 / 64) * 4 = 16 bits of input (on gfx12+) but take i32 arguments.
611  // Add in the zeros here.
612  if (numBits < 32)
613  castInput = rewriter.create<LLVM::ZExtOp>(loc, i32, castInput);
614  operands.push_back(castInput);
615 }
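
// For example (a sketch of the logic above): a signed vector<4xi8> input is
// pushed as the i1 constant 1 (signed) followed by its 32 bits bitcast to an
// i32, while a vector<8xf16> input is pushed through unchanged since its
// elements are wider than 8 bits.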

/// Push the output operand. For many cases this is only pushing the output in
/// the operand list. But when we have f16 -> f16 or bf16 -> bf16 intrinsics,
/// since the same number of VGPRs is used, we need to decide whether to store
/// the result in the upper 16 bits of the VGPRs or in the lower part. To store
/// the result in the lower 16 bits, set subwordOffset to 1; otherwise the
/// result will be stored in the upper part. The subwordOffset must not be set
/// for gfx12, as the instructions have been changed to return fewer registers
/// instead.
static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
                                  Location loc,
                                  const TypeConverter *typeConverter,
                                  Value output, int32_t subwordOffset,
                                  bool clamp, SmallVector<Value, 4> &operands) {
  Type inputType = output.getType();
  auto vectorType = dyn_cast<VectorType>(inputType);
  Type elemType = vectorType.getElementType();
  if (elemType.isBF16())
    output = rewriter.create<LLVM::BitcastOp>(
        loc, vectorType.clone(rewriter.getI16Type()), output);
  operands.push_back(output);
  if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) {
    operands.push_back(createI1Constant(rewriter, loc, subwordOffset));
  } else if (elemType.isInteger(32)) {
    operands.push_back(createI1Constant(rewriter, loc, clamp));
  }
}

/// Return true if `type` is the E5M2 variant of an 8-bit float that is
/// supported by the `_bf8` instructions on the given `chipset`.
static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type) {
  return (chipset == kGfx942 && isa<Float8E5M2FNUZType>(type)) ||
         (hasOcpFp8(chipset) && isa<Float8E5M2Type>(type));
}

/// Return true if `type` is the E4M3FN variant of an 8-bit float that is
/// supported by the `_fp8` instructions on the given `chipset`.
static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type) {
  return (chipset == kGfx942 && isa<Float8E4M3FNUZType>(type)) ||
         (hasOcpFp8(chipset) && isa<Float8E4M3FNType>(type));
}

/// Return the `rocdl` intrinsic corresponding to a MFMA operation `mfma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
                                                  Chipset chipset) {
  uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
           b = mfma.getBlocks();
  Type sourceElem = getElementTypeOrSelf(mfma.getSourceA().getType());
  Type destElem = getElementTypeOrSelf(mfma.getDestC().getType());

  if (sourceElem.isF32() && destElem.isF32()) {
    if (mfma.getReducePrecision() && chipset >= kGfx942) {
      if (m == 32 && n == 32 && k == 4 && b == 1)
        return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
      if (m == 16 && n == 16 && k == 8 && b == 1)
        return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
    }
    if (m == 32 && n == 32 && k == 1 && b == 2)
      return ROCDL::mfma_f32_32x32x1f32::getOperationName();
    if (m == 16 && n == 16 && k == 1 && b == 4)
      return ROCDL::mfma_f32_16x16x1f32::getOperationName();
    if (m == 4 && n == 4 && k == 1 && b == 16)
      return ROCDL::mfma_f32_4x4x1f32::getOperationName();
    if (m == 32 && n == 32 && k == 2 && b == 1)
      return ROCDL::mfma_f32_32x32x2f32::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f32_16x16x4f32::getOperationName();
  }

  if (sourceElem.isF16() && destElem.isF32()) {
    if (chipset >= kGfx950) {
      if (m == 32 && n == 32 && k == 16 && b == 1)
        return ROCDL::mfma_f32_32x32x16_f16::getOperationName();
      if (m == 16 && n == 16 && k == 32 && b == 1)
        return ROCDL::mfma_f32_16x16x32_f16::getOperationName();
    }
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_f32_32x32x4f16::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_f32_16x16x4f16::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_f32_4x4x4f16::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_f32_32x32x8f16::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_f32_16x16x16f16::getOperationName();
  }

  if (sourceElem.isBF16() && destElem.isF32()) {
    if (chipset >= kGfx950) {
      if (m == 32 && n == 32 && k == 16 && b == 1)
        return ROCDL::mfma_f32_32x32x16_bf16::getOperationName();
      if (m == 16 && n == 16 && k == 32 && b == 1)
        return ROCDL::mfma_f32_16x16x32_bf16::getOperationName();
    }
    if (chipset >= kGfx90a) {
      if (m == 32 && n == 32 && k == 4 && b == 2)
        return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
      if (m == 16 && n == 16 && k == 4 && b == 4)
        return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
      if (m == 4 && n == 4 && k == 4 && b == 16)
        return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
      if (m == 32 && n == 32 && k == 8 && b == 1)
        return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
      if (m == 16 && n == 16 && k == 16 && b == 1)
        return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
    }
    if (m == 32 && n == 32 && k == 2 && b == 2)
      return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
    if (m == 16 && n == 16 && k == 2 && b == 4)
      return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
    if (m == 4 && n == 4 && k == 2 && b == 16)
      return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
    if (m == 32 && n == 32 && k == 4 && b == 1)
      return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
    if (m == 16 && n == 16 && k == 8 && b == 1)
      return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
  }

  if (sourceElem.isInteger(8) && destElem.isInteger(32)) {
    if (chipset >= kGfx950) {
      if (m == 32 && n == 32 && k == 32 && b == 1)
        return ROCDL::mfma_i32_32x32x32_i8::getOperationName();
      if (m == 16 && n == 16 && k == 64 && b == 1)
        return ROCDL::mfma_i32_16x16x64_i8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_i32_32x32x4i8::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_i32_16x16x4i8::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_i32_4x4x4i8::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_i32_32x32x8i8::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_i32_16x16x16i8::getOperationName();
    if (m == 32 && n == 32 && k == 16 && b == 1 && chipset >= kGfx942)
      return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
    if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >= kGfx942)
      return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
  }

  if (sourceElem.isF64() && destElem.isF64() && chipset >= kGfx90a) {
    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f64_16x16x4f64::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 4)
      return ROCDL::mfma_f64_4x4x4f64::getOperationName();
  }

  if (destElem.isF32() && typeIsExpectedBf8ForChipset(chipset, sourceElem)) {
    // Known to be correct because there are no scalar f8 instructions and
    // because a length mismatch will have been caught by the verifier.
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
    }
  }

  if (destElem.isF32() && typeIsExpectedFp8ForChipset(chipset, sourceElem)) {
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
    }
  }

  return std::nullopt;
}
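
// As an illustrative example: an amdgpu.mfma with m = n = 32, k = 8,
// blocks = 1 on f16 sources accumulating into f32 selects
// ROCDL::mfma_f32_32x32x8f16 via the table above.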

static std::optional<uint32_t> mfmaTypeSelectCode(Type mlirElemType) {
  return llvm::TypeSwitch<Type, std::optional<uint32_t>>(mlirElemType)
      .Case([](Float8E4M3FNType) { return 0u; })
      .Case([](Float8E5M2Type) { return 1u; })
      .Case([](Float6E2M3FNType) { return 2u; })
      .Case([](Float6E3M2FNType) { return 3u; })
      .Case([](Float4E2M1FNType) { return 4u; })
      .Default([](Type) { return std::nullopt; });
}

/// If there is a scaled MFMA instruction for the input element types `aType`
/// and `bType`, output type `destType`, problem size M, N, K, and B (number of
/// blocks) on the given `chipset`, return a tuple consisting of the
/// OperationName of the intrinsic and the type codes that need to be passed to
/// that intrinsic. Note that this is also used to implement some un-scaled
/// MFMAs, since the compiler represents the ordinary instruction as a "scaled"
/// MFMA with a scale of 0.
static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m,
                        uint32_t n, uint32_t k, uint32_t b, Chipset chipset) {
  aType = getElementTypeOrSelf(aType);
  bType = getElementTypeOrSelf(bType);
  destType = getElementTypeOrSelf(destType);

  if (chipset < kGfx950)
    return std::nullopt;
  if (!isa<Float32Type>(destType))
    return std::nullopt;

  std::optional<uint32_t> aTypeCode = mfmaTypeSelectCode(aType);
  std::optional<uint32_t> bTypeCode = mfmaTypeSelectCode(bType);
  if (!aTypeCode || !bTypeCode)
    return std::nullopt;

  if (m == 32 && n == 32 && k == 64 && b == 1)
    return std::tuple{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(),
                      *aTypeCode, *bTypeCode};
  if (m == 16 && n == 16 && k == 128 && b == 1)
    return std::tuple{
        ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), *aTypeCode,
        *bTypeCode};

  return std::nullopt;
}
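
// For example (illustrative): f8E4M3FN A-operands and f8E5M2 B-operands with
// M = N = 32, K = 64 on gfx950 yield
// {mfma_scale_f32_32x32x64_f8f6f4, /*aTypeCode=*/0, /*bTypeCode=*/1}.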

static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
mfmaOpToScaledIntrinsic(MFMAOp mfma, Chipset chipset) {
  return mfmaOpToScaledIntrinsic(
      mfma.getSourceA().getType(), mfma.getSourceB().getType(),
      mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(),
      mfma.getBlocks(), chipset);
}

static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
  return mfmaOpToScaledIntrinsic(smfma.getSourceA().getType(),
                                 smfma.getSourceB().getType(),
                                 smfma.getDestC().getType(), smfma.getM(),
                                 smfma.getN(), smfma.getK(), 1u, chipset);
}

/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
                                                  Chipset chipset) {
  auto sourceVectorType = dyn_cast<VectorType>(wmma.getSourceA().getType());
  auto sourceBVectorType = dyn_cast<VectorType>(wmma.getSourceB().getType());
  auto destVectorType = dyn_cast<VectorType>(wmma.getDestC().getType());
  auto elemSourceType = sourceVectorType.getElementType();
  auto elemBSourceType = sourceBVectorType.getElementType();
  auto elemDestType = destVectorType.getElementType();

  if (elemSourceType.isF16() && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
  if (elemSourceType.isBF16() && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
  if (elemSourceType.isF16() && elemDestType.isF16())
    return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
  if (elemSourceType.isBF16() && elemDestType.isBF16())
    return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
  if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
    return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
  if (chipset.majorVersion == 11) {
    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
      return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
  }
  if (chipset.majorVersion >= 12) {
    if (isa<Float8E4M3FNType>(elemSourceType) &&
        isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
    if (isa<Float8E4M3FNType>(elemSourceType) &&
        isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
    if (isa<Float8E5M2Type>(elemSourceType) &&
        isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
    if (isa<Float8E5M2Type>(elemSourceType) &&
        isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) {
      bool isWave64 = destVectorType.getNumElements() == 4;
      // This is the ambiguous case. 8 inputs to the wave64 version means that
      // we want the 16x16x32 version, but for wave32 they mean the short form.
      bool has8Inputs = sourceVectorType.getNumElements() == 8;
      if ((isWave64 && has8Inputs) || (!isWave64 && !has8Inputs))
        return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
      return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
    }
  }
  return std::nullopt;
}
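
// Example of the ambiguous iu4 case above (illustrative): on gfx12 with a
// vector<4xi32> accumulator (wave64) and vector<8xi4> sources, the 16x16x32
// intrinsic is chosen; with a vector<8xi32> accumulator (wave32), those same
// 8 inputs select the short 16x16x16 form instead.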

namespace {
struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
  MFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type outType = typeConverter->convertType(op.getDestD().getType());
    Type intrinsicOutType = outType;
    if (auto outVecType = dyn_cast<VectorType>(outType))
      if (outVecType.getElementType().isBF16())
        intrinsicOutType = outVecType.clone(rewriter.getI16Type());

    if (chipset.majorVersion != 9 || chipset < kGfx908)
      return op->emitOpError("MFMA only supported on gfx908+");
    uint32_t getBlgpField = static_cast<uint32_t>(op.getBlgp());
    if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
      if (chipset < kGfx942)
        return op.emitOpError("negation unsupported on older than gfx942");
      getBlgpField |=
          op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
    }
    std::optional<StringRef> maybeIntrinsic = mfmaOpToIntrinsic(op, chipset);
    std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
        maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
    if (!maybeIntrinsic.has_value() && !maybeScaledIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching MFMA size on given chipset");

    bool isScaled =
        !maybeIntrinsic.has_value() && maybeScaledIntrinsic.has_value();
    if (isScaled &&
        (adaptor.getAbid() > 0 || getBlgpField > 0 || op.getCbsz() > 0)) {
      return op.emitOpError(
          "non-default abid, blgp, and cbsz aren't supported on MFMAs that can "
          "be scaled as those fields are used for type information");
    }

    StringRef intrinsicName =
        isScaled ? std::get<0>(*maybeScaledIntrinsic) : *maybeIntrinsic;
    OperationState loweredOp(loc, intrinsicName);
    loweredOp.addTypes(intrinsicOutType);
    loweredOp.addOperands(
        {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
         convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
         adaptor.getDestC()});
    if (isScaled) {
      Value zero = createI32Constant(rewriter, loc, 0);
      auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
      loweredOp.addOperands({createI32Constant(rewriter, loc, aTypeCode),
                             createI32Constant(rewriter, loc, bTypeCode),
                             /*scale A byte=*/zero, /*scale A=*/zero,
                             /*scale B byte=*/zero, /*scale B=*/zero});
    } else {
      loweredOp.addOperands({createI32Constant(rewriter, loc, op.getCbsz()),
                             createI32Constant(rewriter, loc, op.getAbid()),
                             createI32Constant(rewriter, loc, getBlgpField)});
    }
    Value lowered = rewriter.create(loweredOp)->getResult(0);
    if (outType != intrinsicOutType)
      lowered = rewriter.create<LLVM::BitcastOp>(loc, outType, lowered);
    rewriter.replaceOp(op, lowered);
    return success();
  }
};

struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
  ScaledMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type intrinsicOutType = typeConverter->convertType(op.getDestD().getType());

    if (chipset.majorVersion != 9 || chipset < kGfx950)
      return op->emitOpError("scaled MFMA only supported on gfx950+");
    std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
        maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
    if (!maybeScaledIntrinsic.has_value())
      return op.emitOpError(
          "no intrinsic matching scaled MFMA size on given chipset");

    auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
    OperationState loweredOp(loc, intrinsicName);
    loweredOp.addTypes(intrinsicOutType);
    loweredOp.addOperands(
        {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
         convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
         adaptor.getDestC()});
    Value scalesIdxA =
        createI32Constant(rewriter, loc, adaptor.getScalesIdxA());
    Value scalesIdxB =
        createI32Constant(rewriter, loc, adaptor.getScalesIdxB());
    loweredOp.addOperands(
        {createI32Constant(rewriter, loc, aTypeCode),
         createI32Constant(rewriter, loc, bTypeCode),
         /*scales idx A=*/scalesIdxA,
         /*scales A=*/
         castMFMAScaleOperand(rewriter, loc, adaptor.getScalesA()),
         /*scales idx B=*/scalesIdxB,
         /*scales B=*/
         castMFMAScaleOperand(rewriter, loc, adaptor.getScalesB())});
    Value lowered = rewriter.create(loweredOp)->getResult(0);
    rewriter.replaceOp(op, lowered);
    return success();
  }
};

struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
  WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto outType =
        typeConverter->convertType<VectorType>(op.getDestD().getType());
    if (!outType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
      return op->emitOpError("WMMA only supported on gfx11 and gfx12");

    // The WMMA operations represent vectors of bf16s as vectors of i16s, so we
    // need to bitcast bfloats to i16 and then bitcast them back.
    VectorType rawOutType = outType;
    if (outType.getElementType().isBF16())
      rawOutType = outType.clone(rewriter.getI16Type());

    std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);

    if (!maybeIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching WMMA on the given chipset");

    if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0)
      return op.emitOpError("subwordOffset not supported on gfx12+");

    OperationState loweredOp(loc, *maybeIntrinsic);
    loweredOp.addTypes(rawOutType);

    SmallVector<Value, 4> operands;
    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(),
                         adaptor.getSourceA(), op.getSourceA(), operands);
    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(),
                         adaptor.getSourceB(), op.getSourceB(), operands);
    wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(),
                          op.getSubwordOffset(), op.getClamp(), operands);

    loweredOp.addOperands(operands);
    Operation *lowered = rewriter.create(loweredOp);

    Operation *maybeCastBack = lowered;
    if (rawOutType != outType)
      maybeCastBack =
          rewriter.create<LLVM::BitcastOp>(loc, outType, lowered->getResult(0));
    rewriter.replaceOp(op, maybeCastBack->getResults());

    return success();
  }
};

struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
  GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset.majorVersion < 9 || chipset.majorVersion > 10)
      return op.emitOpError("pre-gfx9 and post-gfx10 not supported");

    Location loc = op.getLoc();

    auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
    auto dstMemRefType = cast<MemRefType>(op.getDst().getType());

    // TODO: instead of only transferring one element per thread, we could
    // augment it to transfer multiple elements per thread by issuing multiple
    // `global_load_lds` instructions.
    Type transferType = op.getTransferType();
    size_t loadWidth = [&]() -> size_t {
      if (auto transferVectorType = dyn_cast<VectorType>(transferType)) {
        return transferVectorType.getNumElements() *
               (transferVectorType.getElementTypeBitWidth() / 8);
      }
      return transferType.getIntOrFloatBitWidth() / 8;
    }();

    // Currently only 1, 2, and 4 byte loads are supported.
    if (loadWidth != 1 && loadWidth != 2 && loadWidth != 4)
      return op.emitOpError("chipset unsupported element size");

    Value srcPtr =
        getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
                             adaptor.getSrcIndices());
    Value dstPtr =
        getStridedElementPtr(rewriter, loc, dstMemRefType, adaptor.getDst(),
                             adaptor.getDstIndices());

    rewriter.replaceOpWithNewOp<ROCDL::LoadToLDSOp>(
        op, srcPtr, dstPtr, rewriter.getI32IntegerAttr(loadWidth),
        /*offset=*/rewriter.getI32IntegerAttr(0),
        /*aux=*/rewriter.getI32IntegerAttr(0), ArrayAttr{}, ArrayAttr{},
        ArrayAttr{});

    return success();
  }
};

namespace {
struct ExtPackedFp8OpLowering final
    : public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
  ExtPackedFp8OpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::ExtPackedFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

struct PackedTrunc2xFp8OpLowering final
    : public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> {
  PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter,
                             Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::PackedTrunc2xFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

struct PackedStochRoundFp8OpLowering final
    : public ConvertOpToLLVMPattern<PackedStochRoundFp8Op> {
  PackedStochRoundFp8OpLowering(const LLVMTypeConverter &converter,
                                Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::PackedStochRoundFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PackedStochRoundFp8Op op,
                  PackedStochRoundFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

struct ScaledExtPackedOpLowering final
    : public ConvertOpToLLVMPattern<ScaledExtPackedOp> {
  ScaledExtPackedOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedOp>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

struct PackedScaledTruncOpLowering final
    : public ConvertOpToLLVMPattern<PackedScaledTruncOp> {
  PackedScaledTruncOpLowering(const LLVMTypeConverter &converter,
                              Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::PackedScaledTruncOp>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

} // end namespace

LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
    ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type v4i8 =
      getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type()));
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
  Type f32 = getTypeConverter()->convertType(op.getResult().getType());

  Value source = adaptor.getSource();
  auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
  auto resultVecType = dyn_cast<VectorType>(op.getResult().getType());
  Type sourceElemType = getElementTypeOrSelf(op.getSource());
  // Extend to a v4i8.
  if (!sourceVecType || sourceVecType.getNumElements() < 4) {
    Value longVec = rewriter.create<LLVM::UndefOp>(loc, v4i8);
    if (!sourceVecType) {
      longVec = rewriter.create<LLVM::InsertElementOp>(
          loc, longVec, source, createI32Constant(rewriter, loc, 0));
    } else {
      for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
        Value idx = createI32Constant(rewriter, loc, i);
        Value elem = rewriter.create<LLVM::ExtractElementOp>(loc, source, idx);
        longVec =
            rewriter.create<LLVM::InsertElementOp>(loc, longVec, elem, idx);
      }
    }
    source = longVec;
  }
  Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
  if (resultVecType) {
    if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
      rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source,
                                                        op.getIndex());
    } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
      rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Fp8Op>(op, f32, i32Source,
                                                        op.getIndex());
    }
  } else {
    if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
      rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
                                                      op.getIndex());
    } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
      rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
                                                      op.getIndex());
    }
  }
  return success();
}
1259 
1260 LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
1261  ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
1262  ConversionPatternRewriter &rewriter) const {
1263  Location loc = op.getLoc();
1264  if (chipset != kGfx950)
1265  return rewriter.notifyMatchFailure(
1266  loc, "Scaled fp conversion instructions are not available on target "
1267  "architecture and their emulation is not implemented");
1268  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
1269 
1270  Value source = adaptor.getSource();
1271  Value scale = adaptor.getScale();
1272 
1273  VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
1274  Type sourceElemType = sourceVecType.getElementType();
1275  VectorType destVecType = cast<VectorType>(op.getResult().getType());
1276  Type destElemType = destVecType.getElementType();
1277 
1278  VectorType packedVecType;
1279  if (isa<Float8E5M2Type, Float8E4M3FNType>(sourceElemType)) {
1280  VectorType v4i8 = VectorType::get(4, rewriter.getI8Type());
1281  packedVecType = cast<VectorType>(getTypeConverter()->convertType(v4i8));
1282  } else if (isa<Float4E2M1FNType>(sourceElemType)) {
1283  VectorType v8i4 = VectorType::get(8, rewriter.getI4Type());
1284  packedVecType = cast<VectorType>(getTypeConverter()->convertType(v8i4));
1285  } else {
1286  llvm_unreachable("invalid element type for scaled ext");
1287  }
1288 
1289  // Extend to a packedVectorType
1290  if (sourceVecType.getNumElements() < packedVecType.getNumElements()) {
1291  Value longVec = rewriter.create<LLVM::ZeroOp>(loc, packedVecType);
1292  if (!sourceVecType) {
1293  longVec = rewriter.create<LLVM::InsertElementOp>(
1294  loc, longVec, source, createI32Constant(rewriter, loc, 0));
1295  } else {
1296  for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
1297  Value idx = createI32Constant(rewriter, loc, i);
1298  Value elem = rewriter.create<LLVM::ExtractElementOp>(loc, source, idx);
1299  longVec =
1300  rewriter.create<LLVM::InsertElementOp>(loc, longVec, elem, idx);
1301  }
1302  }
1303  source = longVec;
1304  }
1305  Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
1306 
1307  if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF32())
1308  rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Bf8Op>(
1309  op, destVecType, i32Source, scale, op.getIndex());
1310  else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF16())
1311  rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Bf8Op>(
1312  op, destVecType, i32Source, scale, op.getIndex());
1313  else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isBF16())
1314  rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Bf8Op>(
1315  op, destVecType, i32Source, scale, op.getIndex());
1316  else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF32())
1317  rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp8Op>(
1318  op, destVecType, i32Source, scale, op.getIndex());
1319  else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF16())
1320  rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp8Op>(
1321  op, destVecType, i32Source, scale, op.getIndex());
1322  else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isBF16())
1323  rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp8Op>(
1324  op, destVecType, i32Source, scale, op.getIndex());
1325  else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF32())
1326  rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp4Op>(
1327  op, destVecType, i32Source, scale, op.getIndex());
1328  else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF16())
1329  rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp4Op>(
1330  op, destVecType, i32Source, scale, op.getIndex());
1331  else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isBF16())
1332  rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp4Op>(
1333  op, destVecType, i32Source, scale, op.getIndex());
1334  else
1335  return failure();
1336 
1337  return success();
1338 }
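  // A rough sketch of the emitted IR (ROCDL op spellings assumed, shown for
  // illustration only): extending two f8E5M2 values to f32 with a scale
  // produces a zero-filled widening of the source to four bytes, a bitcast
  // to i32, and a rocdl.cvt.scalef32.pk.f32.bf8 whose index operand selects
  // which packed pair of the i32 is converted.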
1339 
1340 LogicalResult PackedScaledTruncOpLowering::matchAndRewrite(
1341  PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
1342  ConversionPatternRewriter &rewriter) const {
1343  Location loc = op.getLoc();
1344  if (chipset != kGfx950)
1345  return rewriter.notifyMatchFailure(
1346  loc, "Scaled fp conversion instructions are not available on target "
1347  "architecture and their emulation is not implemented");
1348  Type v2i16 = getTypeConverter()->convertType(
1349  VectorType::get(2, rewriter.getI16Type()));
1350  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
1351 
1352  Type resultType = op.getResult().getType();
1353  Type resultElemType = getElementTypeOrSelf(resultType);
1354  VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
1355  Type sourceElemType = sourceVecType.getElementType();
1356 
1357  Type intResultType = isa<Float4E2M1FNType>(resultElemType) ? i32 : v2i16;
1358 
1359  Value source = adaptor.getSource();
1360  Value scale = adaptor.getScale();
1361  Value existing = adaptor.getExisting();
1362  if (existing)
1363  existing = rewriter.create<LLVM::BitcastOp>(loc, intResultType, existing);
1364  else
1365  existing = rewriter.create<LLVM::ZeroOp>(loc, intResultType);
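  // The scaled-trunc intrinsics carry their packed payload as i32 for fp4
  // results and as v2i16 for fp8 results (intResultType above), so an
  // existing destination is bitcast into that form here, and the final
  // result is bitcast back to the fp8/fp4 vector type below.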
1366 
1367  if (sourceVecType.getNumElements() < 2) {
1368  Value c0 = createI32Constant(rewriter, loc, 0);
1369  Value elem0 = rewriter.create<LLVM::ExtractElementOp>(loc, source, c0);
1370  VectorType v2 = VectorType::get(2, sourceElemType);
1371  source = rewriter.create<LLVM::ZeroOp>(loc, v2);
1372  source = rewriter.create<LLVM::InsertElementOp>(loc, source, elem0, c0);
1373  }
1374 
1375  Value sourceA, sourceB;
1376  if (sourceElemType.isF32()) {
1377  Value c0 = createI32Constant(rewriter, loc, 0);
1378  Value c1 = createI32Constant(rewriter, loc, 1);
1379  sourceA = rewriter.create<LLVM::ExtractElementOp>(loc, source, c0);
1380  sourceB = rewriter.create<LLVM::ExtractElementOp>(loc, source, c1);
1381  }
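  // The f32-source intrinsic variants take their two inputs as separate
  // scalar operands (the sourceA/sourceB extracted above); the f16 and bf16
  // variants instead consume the two-element source vector directly, which
  // is why the dispatch below passes `source` for those cases.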
1382 
1383  Value result;
1384  if (sourceElemType.isF32() && isa<Float8E5M2Type>(resultElemType))
1385  result = rewriter.create<ROCDL::CvtScaleF32PkBf8F32Op>(
1386  loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex());
1387  else if (sourceElemType.isF16() && isa<Float8E5M2Type>(resultElemType))
1388  result = rewriter.create<ROCDL::CvtScaleF32PkBf8F16Op>(
1389  loc, intResultType, existing, source, scale, op.getIndex());
1390  else if (sourceElemType.isBF16() && isa<Float8E5M2Type>(resultElemType))
1391  result = rewriter.create<ROCDL::CvtScaleF32PkBf8Bf16Op>(
1392  loc, intResultType, existing, source, scale, op.getIndex());
1393  else if (sourceElemType.isF32() && isa<Float8E4M3FNType>(resultElemType))
1394  result = rewriter.create<ROCDL::CvtScaleF32PkFp8F32Op>(
1395  loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex());
1396  else if (sourceElemType.isF16() && isa<Float8E4M3FNType>(resultElemType))
1397  result = rewriter.create<ROCDL::CvtScaleF32PkFp8F16Op>(
1398  loc, intResultType, existing, source, scale, op.getIndex());
1399  else if (sourceElemType.isBF16() && isa<Float8E4M3FNType>(resultElemType))
1400  result = rewriter.create<ROCDL::CvtScaleF32PkFp8Bf16Op>(
1401  loc, intResultType, existing, source, scale, op.getIndex());
1402  else if (sourceElemType.isF32() && isa<Float4E2M1FNType>(resultElemType))
1403  result = rewriter.create<ROCDL::CvtScaleF32PkFp4F32Op>(
1404  loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex());
1405  else if (sourceElemType.isF16() && isa<Float4E2M1FNType>(resultElemType))
1406  result = rewriter.create<ROCDL::CvtScaleF32PkFp4F16Op>(
1407  loc, intResultType, existing, source, scale, op.getIndex());
1408  else if (sourceElemType.isBF16() && isa<Float4E2M1FNType>(resultElemType))
1409  result = rewriter.create<ROCDL::CvtScaleF32PkFp4Bf16Op>(
1410  loc, intResultType, existing, source, scale, op.getIndex());
1411  else
1412  return failure();
1413 
1414  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
1415  op, getTypeConverter()->convertType(resultType), result);
1416  return success();
1417 }
1418 
1419 LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
1420  PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
1421  ConversionPatternRewriter &rewriter) const {
1422  Location loc = op.getLoc();
1423  if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
1424  return rewriter.notifyMatchFailure(
1425  loc, "Fp8 conversion instructions are not available on target "
1426  "architecture and their emulation is not implemented");
1427  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
1428 
1429  Type resultType = op.getResult().getType();
1430  Type resultElemType = getElementTypeOrSelf(resultType);
1431 
1432  Value sourceA = adaptor.getSourceA();
1433  Value sourceB = adaptor.getSourceB();
1434  if (!sourceB)
1435  sourceB = rewriter.create<LLVM::UndefOp>(loc, sourceA.getType());
1436  Value existing = adaptor.getExisting();
1437  if (existing)
1438  existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
1439  else
1440  existing = rewriter.create<LLVM::UndefOp>(loc, i32);
1441 
1442  Value result;
1443  if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
1444  result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
1445  existing, op.getWordIndex());
1446  else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
1447  result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
1448  existing, op.getWordIndex());
1449 
1450  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
1451  op, getTypeConverter()->convertType(resultType), result);
1452  return success();
1453 }
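  // Worked example (illustrative; ROCDL op spelling assumed): truncating two
  // f32 values into word 0 of an i32-backed vector<4xf8E4M3FN> emits a
  // rocdl.cvt.pk.fp8.f32 that writes bytes 0-1 of `existing`, followed by a
  // bitcast back to the fp8 vector type.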
1454 
1455 LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
1456  PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
1457  ConversionPatternRewriter &rewriter) const {
1458  Location loc = op.getLoc();
1459  if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
1460  return rewriter.notifyMatchFailure(
1461  loc, "Fp8 conversion instructions are not available on target "
1462  "architecture and their emulation is not implemented");
1463  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
1464 
1465  Type resultType = op.getResult().getType();
1466  Type resultElemType = getElementTypeOrSelf(resultType);
1467 
1468  Value source = adaptor.getSource();
1469  Value stoch = adaptor.getStochiasticParam();
1470  Value existing = adaptor.getExisting();
1471  if (existing)
1472  existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
1473  else
1474  existing = rewriter.create<LLVM::UndefOp>(loc, i32);
1475 
1476  Value result;
1477  if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
1478  result = rewriter.create<ROCDL::CvtSrBf8F32Op>(
1479  loc, i32, source, stoch, existing, op.getStoreIndex());
1480  else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
1481  result = rewriter.create<ROCDL::CvtSrFp8F32Op>(
1482  loc, i32, source, stoch, existing, op.getStoreIndex());
1483 
1484  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
1485  op, getTypeConverter()->convertType(resultType), result);
1486  return success();
1487 }
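  // This differs from the plain packed truncation above only in the extra
  // random-bits operand (`stoch`), which the stochastic-rounding conversions
  // use to round the f32 source probabilistically; the surrounding
  // bitcast-convert-bitcast structure is identical.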
1488 
 1489 // Lowers the amdgpu.dpp operation to the corresponding ROCDL
 1490 // DPPUpdateOp (DPP data-movement) instruction.
1491 struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
1492  AMDGPUDPPLowering(const LLVMTypeConverter &converter, Chipset chipset)
1493  : ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
1494  Chipset chipset;
1495 
1496  LogicalResult
1497  matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
1498  ConversionPatternRewriter &rewriter) const override {
1499 
1500  // Convert the source operand to the corresponding LLVM type
1501  Location loc = DppOp.getLoc();
1502  Value src = adaptor.getSrc();
1503  Value old = adaptor.getOld();
1504  Type srcType = src.getType();
1505  Type oldType = old.getType();
1506  Type llvmType = nullptr;
1507  if (srcType.getIntOrFloatBitWidth() < 32) {
1508  llvmType = rewriter.getI32Type();
1509  } else if (isa<FloatType>(srcType)) {
1510  llvmType = (srcType.getIntOrFloatBitWidth() == 32)
1511  ? rewriter.getF32Type()
1512  : rewriter.getF64Type();
1513  } else if (isa<IntegerType>(srcType)) {
1514  llvmType = (srcType.getIntOrFloatBitWidth() == 32)
1515  ? rewriter.getI32Type()
1516  : rewriter.getI64Type();
1517  }
1518  auto llvmSrcIntType = typeConverter->convertType(
1519  rewriter.getIntegerType(srcType.getIntOrFloatBitWidth()));
1520 
 1521  // Pack sub-32-bit operands into an i32 via vector insertion and bitcasts.
1522  auto convertOperand = [&](Value operand, Type operandType) {
1523  if (operandType.getIntOrFloatBitWidth() <= 16) {
1524  if (llvm::isa<FloatType>(operandType)) {
1525  operand =
1526  rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
1527  }
1528  auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
1529  32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
1530  Value undefVec = rewriter.create<LLVM::UndefOp>(loc, llvmVecType);
1531  operand = rewriter.create<LLVM::InsertElementOp>(
1532  loc, undefVec, operand, createI32Constant(rewriter, loc, 0));
1533  operand = rewriter.create<LLVM::BitcastOp>(loc, llvmType, operand);
1534  }
1535  return operand;
1536  };
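  // For example, following the lambda above, an f16 operand is bitcast to
  // i16, inserted into lane 0 of an undef vector<2xi16>, and bitcast to the
  // i32 the DPP intrinsic operates on; the low bits are recovered by the
  // truncation at the end of this pattern.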
1537 
1538  src = convertOperand(src, srcType);
1539  old = convertOperand(old, oldType);
1540 
 1541  // These DPP control values are taken from llvm/lib/Target/AMDGPU/SIDefines.h.
1542  enum DppCtrl : unsigned {
1543  ROW_SHL0 = 0x100,
1544  ROW_SHR0 = 0x110,
1545  ROW_ROR0 = 0x120,
1546  WAVE_SHL1 = 0x130,
1547  WAVE_ROL1 = 0x134,
1548  WAVE_SHR1 = 0x138,
1549  WAVE_ROR1 = 0x13C,
1550  ROW_MIRROR = 0x140,
1551  ROW_HALF_MIRROR = 0x141,
1552  BCAST15 = 0x142,
1553  BCAST31 = 0x143,
1554  };
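  // Worked encodings (derived from the switch below): the identity
  // quad_perm [0, 1, 2, 3] packs its four 2-bit selectors into
  // 0b11100100 = 0xE4, and row_shl by 1 encodes as ROW_SHL0 + 1 = 0x101.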
1555 
1556  auto kind = DppOp.getKind();
1557  auto permArgument = DppOp.getPermArgument();
1558  uint32_t DppCtrl = 0;
1559 
1560  switch (kind) {
1561 
1562  case DPPPerm::quad_perm:
1563  if (auto quadPermAttr = cast<ArrayAttr>(*permArgument)) {
1564  int32_t i = 0;
1565  for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
1566  uint32_t num = elem.getInt();
1567  DppCtrl |= num << (i * 2);
1568  i++;
1569  }
1570  }
1571  break;
1572  case DPPPerm::row_shl:
1573  if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
1574  DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
1575  }
1576  break;
1577  case DPPPerm::row_shr:
1578  if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
1579  DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
1580  }
1581  break;
1582  case DPPPerm::row_ror:
1583  if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
1584  DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
1585  }
1586  break;
1587  case DPPPerm::wave_shl:
1588  DppCtrl = DppCtrl::WAVE_SHL1;
1589  break;
1590  case DPPPerm::wave_shr:
1591  DppCtrl = DppCtrl::WAVE_SHR1;
1592  break;
1593  case DPPPerm::wave_rol:
1594  DppCtrl = DppCtrl::WAVE_ROL1;
1595  break;
1596  case DPPPerm::wave_ror:
1597  DppCtrl = DppCtrl::WAVE_ROR1;
1598  break;
1599  case DPPPerm::row_mirror:
1600  DppCtrl = DppCtrl::ROW_MIRROR;
1601  break;
1602  case DPPPerm::row_half_mirror:
1603  DppCtrl = DppCtrl::ROW_HALF_MIRROR;
1604  break;
1605  case DPPPerm::row_bcast_15:
1606  DppCtrl = DppCtrl::BCAST15;
1607  break;
1608  case DPPPerm::row_bcast_31:
1609  DppCtrl = DppCtrl::BCAST31;
1610  break;
1611  }
1612 
 1613  // Read the row_mask, bank_mask, and bound_ctrl attribute values off
 1614  // the op; they are forwarded to the ROCDL DPP update op below.
1615  auto rowMask = DppOp->getAttrOfType<IntegerAttr>("row_mask").getInt();
1616  auto bankMask = DppOp->getAttrOfType<IntegerAttr>("bank_mask").getInt();
1617  bool boundCtrl = DppOp->getAttrOfType<BoolAttr>("bound_ctrl").getValue();
1618 
 1619  // Create a ROCDL DPPUpdateOp with the computed control value and masks.
1620  auto dppMovOp = rewriter.create<ROCDL::DPPUpdateOp>(
1621  loc, llvmType, old, src, DppCtrl, rowMask, bankMask, boundCtrl);
1622 
1623  Value result = dppMovOp.getRes();
1624  if (srcType.getIntOrFloatBitWidth() < 32) {
1625  result = rewriter.create<LLVM::TruncOp>(loc, llvmSrcIntType, result);
1626  if (!llvm::isa<IntegerType>(srcType)) {
1627  result = rewriter.create<LLVM::BitcastOp>(loc, srcType, result);
1628  }
1629  }
1630 
 1631  // Replace the amdgpu.dpp op with the (possibly truncated and
 1632  // bitcast) result of the ROCDL DPP update.
1633  rewriter.replaceOp(DppOp, ValueRange(result));
1634  return success();
1635  }
1636 };
1637 
1638 struct AMDGPUSwizzleBitModeLowering
1639  : public ConvertOpToLLVMPattern<SwizzleBitModeOp> {
 1640  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
 1641 
1642  LogicalResult
1643  matchAndRewrite(SwizzleBitModeOp op, OpAdaptor adaptor,
1644  ConversionPatternRewriter &rewriter) const override {
1645  Location loc = op.getLoc();
1646  Type i32 = rewriter.getI32Type();
1647  Value src = adaptor.getSrc();
1648  SmallVector<Value> decomposed =
1649  LLVM::decomposeValue(rewriter, loc, src, i32);
1650  unsigned andMask = op.getAndMask();
1651  unsigned orMask = op.getOrMask();
1652  unsigned xorMask = op.getXorMask();
1653 
 1654  // Bit 15 is 0 for the BitMode swizzle.
1655  // https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/
1656  unsigned mask = andMask | (orMask << 5) | (xorMask << 10);
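  // For example, andMask = 0x1f, orMask = 0, xorMask = 1 gives
  // mask = 0x41f, which flips bit 0 of each lane id and therefore swaps
  // adjacent lanes.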
1657  Value maskValue = createI32Constant(rewriter, loc, mask);
1658  SmallVector<Value> swizzled;
1659  for (Value v : decomposed) {
1660  Value res =
1661  rewriter.create<ROCDL::DsSwizzleOp>(loc, v.getType(), v, maskValue);
1662  swizzled.emplace_back(res);
1663  }
1664 
1665  Value result = LLVM::composeValue(rewriter, loc, swizzled, src.getType());
1666  rewriter.replaceOp(op, result);
1667  return success();
1668  }
1669 };
1670 
1671 struct ConvertAMDGPUToROCDLPass
1672  : public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
1673  using Base::Base;
1674 
1675  void runOnOperation() override {
1676  MLIRContext *ctx = &getContext();
1677  FailureOr<Chipset> maybeChipset = Chipset::parse(chipset);
1678  if (failed(maybeChipset)) {
1679  emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
1680  return signalPassFailure();
1681  }
1682 
 1683  RewritePatternSet patterns(ctx);
 1684  LLVMTypeConverter converter(ctx);
1685  populateAMDGPUToROCDLConversionPatterns(converter, patterns, *maybeChipset);
1686  LLVMConversionTarget target(getContext());
1687  target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
1688  target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
1689  target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
1690  if (failed(applyPartialConversion(getOperation(), target,
1691  std::move(patterns))))
1692  signalPassFailure();
1693  }
1694 };
1695 } // namespace
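// A minimal usage sketch (pass flag assumed from the generated pass
// registration): the conversion is typically driven as
//   mlir-opt --convert-amdgpu-to-rocdl=chipset=gfx942 input.mlir
// where the chipset string must satisfy Chipset::parse, e.g. "gfx90a".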
1696 
 1697 void mlir::populateAMDGPUMemorySpaceAttributeConversions(
 1698  TypeConverter &typeConverter) {
1699  typeConverter.addTypeAttributeConversion(
1700  [](BaseMemRefType type, amdgpu::AddressSpaceAttr as)
 1701  -> TypeConverter::AttributeConversionResult {
 1702  MLIRContext *ctx = as.getContext();
1703  Type i64 = IntegerType::get(ctx, 64);
1704  switch (as.getValue()) {
1705  case amdgpu::AddressSpace::FatRawBuffer:
1706  return IntegerAttr::get(i64, 7);
1707  case amdgpu::AddressSpace::BufferRsrc:
1708  return IntegerAttr::get(i64, 8);
1709  case amdgpu::AddressSpace::FatStructuredBuffer:
1710  return IntegerAttr::get(i64, 9);
1711  }
 1712  return TypeConverter::AttributeConversionResult::abort();
 1713  });
1714 }
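// For example, once these conversions are installed, a memref in the
// fat_raw_buffer memory space (spelled #amdgpu.address_space<fat_raw_buffer>
// in IR; attribute syntax assumed) lowers to an LLVM pointer in
// addrspace(7).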
1715 
 1716 void mlir::populateAMDGPUToROCDLConversionPatterns(
 1717  LLVMTypeConverter &converter, RewritePatternSet &patterns,
 1718  Chipset chipset) {
 1719  populateAMDGPUMemorySpaceAttributeConversions(converter);
 1720  patterns
1721  .add<FatRawBufferCastLowering,
1722  RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
1723  RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
1724  RawBufferOpLowering<RawBufferAtomicFaddOp,
1725  ROCDL::RawPtrBufferAtomicFaddOp>,
1726  RawBufferOpLowering<RawBufferAtomicFmaxOp,
1727  ROCDL::RawPtrBufferAtomicFmaxOp>,
1728  RawBufferOpLowering<RawBufferAtomicSmaxOp,
1729  ROCDL::RawPtrBufferAtomicSmaxOp>,
1730  RawBufferOpLowering<RawBufferAtomicUminOp,
1731  ROCDL::RawPtrBufferAtomicUminOp>,
1732  RawBufferOpLowering<RawBufferAtomicCmpswapOp,
1733  ROCDL::RawPtrBufferAtomicCmpSwap>,
1734  AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
1735  MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering,
1736  ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
1737  PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
1738  PackedStochRoundFp8OpLowering, GatherToLDSOpLowering>(converter,
1739  chipset);
1740  patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
1741 }