//===- AMDGPUToROCDL.cpp - AMDGPU to ROCDL dialect conversion -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"

#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"

#include "llvm/ADT/STLExtras.h"
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_CONVERTAMDGPUTOROCDL
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::amdgpu;

static Value createI32Constant(ConversionPatternRewriter &rewriter,
                               Location loc, int32_t value) {
  Type llvmI32 = rewriter.getI32Type();
  return rewriter.create<LLVM::ConstantOp>(loc, llvmI32, value);
}

static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
                              bool value) {
  Type llvmI1 = rewriter.getI1Type();
  return rewriter.create<LLVM::ConstantOp>(loc, llvmI1, value);
}

namespace {
/// Define lowering patterns for raw buffer ops.
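/// (Illustrative mapping, for orientation: e.g. amdgpu.raw_buffer_load lowers
/// to rocdl.raw.ptr.buffer.load on a buffer resource descriptor that
/// matchAndRewrite builds below.)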
template <typename GpuOp, typename Intrinsic>
struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
  RawBufferOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}

  Chipset chipset;
  static constexpr uint32_t maxVectorOpWidth = 128;

  LogicalResult
  matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = gpuOp.getLoc();
    Value memref = adaptor.getMemref();
    Value unconvertedMemref = gpuOp.getMemref();
    MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());

    if (chipset.majorVersion < 9)
      return gpuOp.emitOpError("raw buffer ops require GCN or higher");

    Value storeData = adaptor.getODSOperands(0)[0];
    if (storeData == memref) // no write component to this op
      storeData = Value();
    Type wantedDataType;
    if (storeData)
      wantedDataType = storeData.getType();
    else
      wantedDataType = gpuOp.getODSResults(0)[0].getType();

    Value atomicCmpData = Value();
    // Operand index 1 of a load is the indices; trying to read them can crash.
    if (storeData) {
      Value maybeCmpData = adaptor.getODSOperands(1)[0];
      if (maybeCmpData != memref)
        atomicCmpData = maybeCmpData;
    }

    Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);

    Type i32 = rewriter.getI32Type();
    Type llvmI32 = this->typeConverter->convertType(i32);
    Type llvmI16 = this->typeConverter->convertType(rewriter.getI16Type());

    int64_t elementByteWidth = memrefType.getElementTypeBitWidth() / 8;
    Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);

    // If we want to load a vector<NxT> with total size <= 32 bits, use a
    // scalar load and bitcast it. Similarly, if bitsize(T) < 32 and the total
    // load size is >= 32, use a vector load of (N * bitsize(T)) / 32 x i32 and
    // bitcast. Also, the CAS intrinsic requires integer operands, so bitcast
    // any floats to integers. On top of all this, cast bfloat (vectors) to
    // i16 since the backend doesn't currently support bfloat on these
    // operations.
    Type llvmBufferValType = llvmWantedDataType;
    if (wantedDataType.isBF16())
      llvmBufferValType = rewriter.getI16Type();
    if (auto wantedVecType = dyn_cast<VectorType>(wantedDataType))
      if (wantedVecType.getElementType().isBF16())
        llvmBufferValType = wantedVecType.clone(rewriter.getI16Type());
    if (atomicCmpData) {
      if (isa<VectorType>(wantedDataType))
        return gpuOp.emitOpError("vector compare-and-swap does not exist");
      if (auto floatType = dyn_cast<FloatType>(wantedDataType))
        llvmBufferValType = this->getTypeConverter()->convertType(
            rewriter.getIntegerType(floatType.getWidth()));
    }
    if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
      uint32_t elemBits = dataVector.getElementTypeBitWidth();
      uint32_t totalBits = elemBits * dataVector.getNumElements();
      if (totalBits > maxVectorOpWidth)
        return gpuOp.emitOpError(
            "Total width of loads or stores must be no more than " +
            Twine(maxVectorOpWidth) + " bits, but we call for " +
            Twine(totalBits) +
            " bits. This should've been caught in validation");
      if (elemBits < 32) {
        if (totalBits > 32) {
          if (totalBits % 32 != 0)
            return gpuOp.emitOpError("Load or store of more than 32-bits that "
                                     "doesn't fit into words. Can't happen\n");
          llvmBufferValType = this->typeConverter->convertType(
              VectorType::get(totalBits / 32, i32));
        } else {
          llvmBufferValType = this->typeConverter->convertType(
              rewriter.getIntegerType(totalBits));
        }
      }
    }

    SmallVector<Value, 6> args;
    if (storeData) {
      if (llvmBufferValType != llvmWantedDataType) {
        Value castForStore =
            rewriter.create<LLVM::BitcastOp>(loc, llvmBufferValType, storeData);
        args.push_back(castForStore);
      } else {
        args.push_back(storeData);
      }
    }

    if (atomicCmpData) {
      if (llvmBufferValType != llvmWantedDataType) {
        Value castForCmp = rewriter.create<LLVM::BitcastOp>(
            loc, llvmBufferValType, atomicCmpData);
        args.push_back(castForCmp);
      } else {
        args.push_back(atomicCmpData);
      }
    }

    // Construct buffer descriptor from memref, attributes
    int64_t offset = 0;
    SmallVector<int64_t, 5> strides;
    if (failed(getStridesAndOffset(memrefType, strides, offset)))
      return gpuOp.emitOpError("Can't lower non-stride-offset memrefs");

    MemRefDescriptor memrefDescriptor(memref);

    Value ptr = memrefDescriptor.alignedPtr(rewriter, loc);
    // The stride value is always 0 for raw buffers. This also disables
    // swizzling.
    Value stride = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI16, rewriter.getI16IntegerAttr(0));
    Value numRecords;
    if (memrefType.hasStaticShape()) {
      numRecords = createI32Constant(
          rewriter, loc,
          static_cast<int32_t>(memrefType.getNumElements() * elementByteWidth));
    } else {
      Value maxIndex;
      for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
        Value size = memrefDescriptor.size(rewriter, loc, i);
        Value stride = memrefDescriptor.stride(rewriter, loc, i);
        stride = rewriter.create<LLVM::MulOp>(loc, stride, byteWidthConst);
        Value maxThisDim = rewriter.create<LLVM::MulOp>(loc, size, stride);
        maxIndex = maxIndex ? rewriter.create<LLVM::UMaxOp>(loc, maxIndex,
                                                            maxThisDim)
                            : maxThisDim;
      }
      numRecords = rewriter.create<LLVM::TruncOp>(loc, llvmI32, maxIndex);
    }
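    // Illustrative arithmetic: for a static memref<16x4xf32>, numRecords is
    // 16 * 4 * 4 = 256 bytes.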

    // Flag word:
    // bits 0-11: dst sel, ignored by these intrinsics
    // bits 12-14: num format (ignored, must be nonzero, 7=float)
    // bits 15-18: data format (ignored, must be nonzero, 4=32bit)
    // bit 19: In nested heap (0 here)
    // bit 20: Behavior on unmap (0 means "return 0 / ignore")
    // bits 21-22: Index stride for swizzles (N/A)
    // bit 23: Add thread ID (0)
    // bit 24: Reserved to 1 (RDNA) or 0 (CDNA)
    // bits 25-26: Reserved (0)
    // bit 27: Buffer is non-volatile (CDNA only)
    // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 =
    //  none, 3 = either swizzles or testing against offset field) RDNA only
    // bits 30-31: Type (must be 0)
    uint32_t flags = (7 << 12) | (4 << 15);
    if (chipset.majorVersion >= 10) {
      flags |= (1 << 24);
      uint32_t oob = adaptor.getBoundsCheck() ? 3 : 2;
      flags |= (oob << 28);
    }
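    // Illustrative values: the base flag word is 0x27000; on gfx10+ it becomes
    // 0x31027000 with bounds checking enabled (oob select = 3) and 0x21027000
    // with it disabled (oob select = 2).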
    Value flagsConst = createI32Constant(rewriter, loc, flags);
    Type rsrcType = LLVM::LLVMPointerType::get(rewriter.getContext(), 8);
    Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
        loc, rsrcType, ptr, stride, numRecords, flagsConst);
    args.push_back(resource);

    // Indexing (voffset)
    Value voffset = createI32Constant(rewriter, loc, 0);
    for (auto pair : llvm::enumerate(adaptor.getIndices())) {
      size_t i = pair.index();
      Value index = pair.value();
      Value strideOp;
      if (ShapedType::isDynamic(strides[i])) {
        strideOp = rewriter.create<LLVM::MulOp>(
            loc, memrefDescriptor.stride(rewriter, loc, i), byteWidthConst);
      } else {
        strideOp =
            createI32Constant(rewriter, loc, strides[i] * elementByteWidth);
      }
      index = rewriter.create<LLVM::MulOp>(loc, index, strideOp);
      voffset = rewriter.create<LLVM::AddOp>(loc, voffset, index);
    }
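    // Illustrative arithmetic: for a memref<4x8xf32> indexed as [i, j]
    // (strides {8, 1}, 4-byte elements), this loop computes
    // voffset = i * 32 + j * 4.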
    if (adaptor.getIndexOffset()) {
      int32_t indexOffset = *gpuOp.getIndexOffset() * elementByteWidth;
      Value extraOffsetConst = createI32Constant(rewriter, loc, indexOffset);
      voffset =
          voffset ? rewriter.create<LLVM::AddOp>(loc, voffset, extraOffsetConst)
                  : extraOffsetConst;
    }
    args.push_back(voffset);

    Value sgprOffset = adaptor.getSgprOffset();
    if (!sgprOffset)
      sgprOffset = createI32Constant(rewriter, loc, 0);
    if (ShapedType::isDynamic(offset))
      sgprOffset = rewriter.create<LLVM::AddOp>(
          loc, memrefDescriptor.offset(rewriter, loc), sgprOffset);
    else if (offset > 0)
      sgprOffset = rewriter.create<LLVM::AddOp>(
          loc, sgprOffset, createI32Constant(rewriter, loc, offset));
    args.push_back(sgprOffset);

    // bit 0: GLC = 0 (atomics drop value, less coherency)
    // bits 1-2: SLC, DLC = 0 (similarly)
    // bit 3: swizzled (0 for raw)
    args.push_back(createI32Constant(rewriter, loc, 0));

    llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
                                           llvmBufferValType);
    Operation *lowered = rewriter.create<Intrinsic>(loc, resultTypes, args,
                                                    ArrayRef<NamedAttribute>());
    if (lowered->getNumResults() == 1) {
      Value replacement = lowered->getResult(0);
      if (llvmBufferValType != llvmWantedDataType) {
        replacement = rewriter.create<LLVM::BitcastOp>(loc, llvmWantedDataType,
                                                       replacement);
      }
      rewriter.replaceOp(gpuOp, replacement);
    } else {
      rewriter.eraseOp(gpuOp);
    }
    return success();
  }
};

struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
  LDSBarrierOpLowering(LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    bool requiresInlineAsm =
        chipset.majorVersion < 9 ||
        (chipset.majorVersion == 9 && chipset.minorVersion < 0x0a) ||
        (chipset.majorVersion == 11);

    if (requiresInlineAsm) {
      auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
                                                      LLVM::AsmDialect::AD_ATT);
      const char *asmStr =
          ";;;WARNING: BREAKS DEBUG WATCHES\ns_waitcnt lgkmcnt(0)\ns_barrier";
      const char *constraints = "";
      rewriter.replaceOpWithNewOp<LLVM::InlineAsmOp>(
          op,
          /*resultTypes=*/TypeRange(), /*operands=*/ValueRange(),
          /*asm_string=*/asmStr, constraints, /*has_side_effects=*/true,
          /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
          /*operand_attrs=*/ArrayAttr());
      return success();
    }
    constexpr int32_t ldsOnlyBitsGfx6789 = ~(0x1f << 8);
    constexpr int32_t ldsOnlyBitsGfx10 = ~(0x3f << 8);
    // Left in place in case someone disables the inline ASM path or future
    // chipsets use the same bit pattern.
    constexpr int32_t ldsOnlyBitsGfx11 = ~(0x3f << 4);
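    // Reading these masks (a sketch): each one clears only the lgkmcnt field
    // of the s_waitcnt immediate, so the wait covers outstanding LDS traffic
    // while leaving the other counters at their "don't wait" maximum.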

    int32_t ldsOnlyBits;
    if (chipset.majorVersion == 11)
      ldsOnlyBits = ldsOnlyBitsGfx11;
    else if (chipset.majorVersion == 10)
      ldsOnlyBits = ldsOnlyBitsGfx10;
    else if (chipset.majorVersion <= 9)
      ldsOnlyBits = ldsOnlyBitsGfx6789;
    else
      return op.emitOpError(
                 "don't know how to lower this for chipset major version")
             << chipset.majorVersion;

    Location loc = op->getLoc();
    rewriter.create<ROCDL::WaitcntOp>(loc, ldsOnlyBits);
    rewriter.replaceOpWithNewOp<ROCDL::SBarrierOp>(op);
    return success();
  }
};
} // namespace

/// If `input` is a vector of bytes, concatenate those bytes in little-endian
/// order to form a single integer of size 8 * [vector length]. This works
/// around a wart in the AMDGPU intrinsics where operations that logically take
/// vectors of bytes instead take integers. Since we do not want to expose this
/// implementation detail to MLIR, we correct for it here.
///
/// In addition, convert vectors of LLVM bfloats to vectors of i16, since AMDGPU
/// MFMA intrinsics pre-date the bfloat type.
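/// (Worked example, illustrative: a vector<4xi8> holding [0x01, 0x02, 0x03,
/// 0x04] becomes the i32 0x04030201, since byte i is shifted left by 8 * i
/// bits before being OR'd into the result.)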
static Value mfmaConcatIfNeeded(ConversionPatternRewriter &rewriter,
                                Location loc, Value input) {
  Type inputType = input.getType();
  if (auto vectorType = dyn_cast<VectorType>(inputType)) {
    if (vectorType.getElementType().isBF16())
      return rewriter.create<LLVM::BitcastOp>(
          loc, vectorType.clone(rewriter.getI16Type()), input);

    if (!vectorType.getElementType().isInteger(8))
      return input;
    int64_t numBytes = vectorType.getNumElements();
    Type destType = rewriter.getIntegerType(numBytes * 8);
    Value result = rewriter.create<LLVM::ConstantOp>(
        loc, destType, rewriter.getIntegerAttr(destType, 0));
    for (int64_t i = 0; i < numBytes; ++i) {
      Value idxConst = createI32Constant(rewriter, loc, i);
      Value element =
          rewriter.create<LLVM::ExtractElementOp>(loc, input, idxConst);
      Value extended = rewriter.create<LLVM::ZExtOp>(loc, destType, element);
      Value shiftConst = rewriter.create<LLVM::ConstantOp>(
          loc, destType, rewriter.getIntegerAttr(destType, i * 8));
      Value shifted = rewriter.create<LLVM::ShlOp>(loc, extended, shiftConst);
      result = rewriter.create<LLVM::OrOp>(loc, result, shifted);
    }
    return result;
  }
  return input;
}

/// Push an input operand. If it is a float type, nothing to do. If it is
/// an integer type, then we need to also push its signedness (1 for signed, 0
/// for unsigned) and we need to pack the input 16xi8 vector into a 4xi32
/// vector. We also need to convert bfloat inputs to i16 to account for the lack
/// of bfloat support in the WMMA intrinsics themselves.
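/// (Illustrative: a vector<16xi8> input is bitcast to vector<4xi32> and pushed
/// after an i1 sign flag that is 1 for signed elements.)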
static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
                                 Location loc,
                                 const TypeConverter *typeConverter,
                                 bool isUnsigned, Value llvmInput,
                                 SmallVector<Value, 4> &operands) {
  Type inputType = llvmInput.getType();
  auto vectorType = dyn_cast<VectorType>(inputType);
  Type elemType = vectorType.getElementType();

  if (elemType.isBF16())
    llvmInput = rewriter.create<LLVM::BitcastOp>(
        loc, vectorType.clone(rewriter.getI16Type()), llvmInput);
  if (!elemType.isInteger(8)) {
    operands.push_back(llvmInput);
    return;
  }

  int64_t numBytes = vectorType.getNumElements();
  Type i32 = rewriter.getI32Type();
  VectorType vectorType32bits = VectorType::get(numBytes * 8 / 32, i32);
  auto llvmVectorType32bits = typeConverter->convertType(vectorType32bits);

  Value result = rewriter.createOrFold<LLVM::BitcastOp>(
      loc, llvmVectorType32bits, llvmInput);

  // If the element type carries explicit signedness, it overrides the
  // isUnsigned flag.
  bool localIsUnsigned = isUnsigned;
  if (elemType.isUnsignedInteger(8)) {
    localIsUnsigned = true;
  } else if (elemType.isSignedInteger(8)) {
    localIsUnsigned = false;
  }
  Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
  operands.push_back(sign);
  operands.push_back(result);
}

/// Push the output operand. For many cases this is only pushing the output in
/// the operand list. But when we have f16 -> f16 or bf16 -> bf16 intrinsics,
/// since the same number of VGPRs is used, we need to decide whether to store
/// the result in the upper or lower 16 bits of the VGPRs. To store the result
/// in the lower 16 bits, set subwordOffset to 1; otherwise the result will be
/// stored in the upper part.
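/// (Illustrative: a 16-bit accumulator (f16/bf16/i16) is followed by the i1
/// subwordOffset flag, while an i32 accumulator is followed by the clamp flag
/// instead.)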
static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
                                  Location loc,
                                  const TypeConverter *typeConverter,
                                  Value output, int32_t subwordOffset,
                                  bool clamp, SmallVector<Value, 4> &operands) {
  Type inputType = output.getType();
  auto vectorType = dyn_cast<VectorType>(inputType);
  Type elemType = vectorType.getElementType();
  if (elemType.isBF16())
    output = rewriter.create<LLVM::BitcastOp>(
        loc, vectorType.clone(rewriter.getI16Type()), output);
  operands.push_back(output);
  if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) {
    operands.push_back(createI1Constant(rewriter, loc, subwordOffset));
  } else if (elemType.isInteger(32)) {
    operands.push_back(createI1Constant(rewriter, loc, clamp));
  }
}

/// Return the `rocdl` intrinsic corresponding to a MFMA operation `mfma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
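/// (Illustrative: m = n = 32, k = 8, blocks = 1 with f16 sources and an f32
/// destination selects rocdl.mfma.f32.32x32x8f16.)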
static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
                                                  Chipset chipset) {
  uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
           b = mfma.getBlocks();
  Type sourceElem = mfma.getSourceA().getType();
  if (auto sourceType = dyn_cast<VectorType>(sourceElem))
    sourceElem = sourceType.getElementType();
  Type destElem = mfma.getDestC().getType();
  if (auto destType = dyn_cast<VectorType>(destElem))
    destElem = destType.getElementType();

  if (sourceElem.isF32() && destElem.isF32()) {
    if (mfma.getReducePrecision() && chipset.minorVersion >= 0x40) {
      if (m == 32 && n == 32 && k == 4 && b == 1)
        return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
      if (m == 16 && n == 16 && k == 8 && b == 1)
        return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
    }
    if (m == 32 && n == 32 && k == 1 && b == 2)
      return ROCDL::mfma_f32_32x32x1f32::getOperationName();
    if (m == 16 && n == 16 && k == 1 && b == 4)
      return ROCDL::mfma_f32_16x16x1f32::getOperationName();
    if (m == 4 && n == 4 && k == 1 && b == 16)
      return ROCDL::mfma_f32_4x4x1f32::getOperationName();
    if (m == 32 && n == 32 && k == 2 && b == 1)
      return ROCDL::mfma_f32_32x32x2f32::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f32_16x16x4f32::getOperationName();
  }

  if (sourceElem.isF16() && destElem.isF32()) {
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_f32_32x32x4f16::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_f32_16x16x4f16::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_f32_4x4x4f16::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_f32_32x32x8f16::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_f32_16x16x16f16::getOperationName();
  }

  if (sourceElem.isBF16() && destElem.isF32() && chipset.minorVersion >= 0x0a) {
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
  }

  if (sourceElem.isBF16() && destElem.isF32()) {
    if (m == 32 && n == 32 && k == 2 && b == 2)
      return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
    if (m == 16 && n == 16 && k == 2 && b == 4)
      return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
    if (m == 4 && n == 4 && k == 2 && b == 16)
      return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
    if (m == 32 && n == 32 && k == 4 && b == 1)
      return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
    if (m == 16 && n == 16 && k == 8 && b == 1)
      return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
  }

  if (isa<IntegerType>(sourceElem) && destElem.isInteger(32)) {
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_i32_32x32x4i8::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_i32_16x16x4i8::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_i32_4x4x4i8::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_i32_32x32x8i8::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_i32_16x16x16i8::getOperationName();
    if (m == 32 && n == 32 && k == 16 && b == 1 && chipset.minorVersion >= 0x40)
      return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
    if (m == 16 && n == 16 && k == 32 && b == 1 && chipset.minorVersion >= 0x40)
      return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
  }

  if (sourceElem.isF64() && destElem.isF64() && chipset.minorVersion >= 0x0a) {
    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f64_16x16x4f64::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 4)
      return ROCDL::mfma_f64_4x4x4f64::getOperationName();
  }

  if (sourceElem.isFloat8E5M2FNUZ() && destElem.isF32() &&
      chipset.minorVersion >= 0x40) {
    // Known to be correct because there are no scalar f8 instructions and
    // because a length mismatch will have been caught by the verifier.
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (sourceBElem.isFloat8E5M2FNUZ())
        return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
      if (sourceBElem.isFloat8E4M3FNUZ())
        return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (sourceBElem.isFloat8E5M2FNUZ())
        return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
      if (sourceBElem.isFloat8E4M3FNUZ())
        return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
    }
  }

  if (sourceElem.isFloat8E4M3FNUZ() && destElem.isF32() &&
      chipset.minorVersion >= 0x40) {
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (sourceBElem.isFloat8E5M2FNUZ())
        return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
      if (sourceBElem.isFloat8E4M3FNUZ())
        return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (sourceBElem.isFloat8E5M2FNUZ())
        return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
      if (sourceBElem.isFloat8E4M3FNUZ())
        return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
    }
  }

  return std::nullopt;
}

/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
                                                  Chipset chipset) {
  auto sourceVectorType = dyn_cast<VectorType>(wmma.getSourceA().getType());
  auto destVectorType = dyn_cast<VectorType>(wmma.getDestC().getType());
  auto elemSourceType = sourceVectorType.getElementType();
  auto elemDestType = destVectorType.getElementType();

  if (elemSourceType.isF16() && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
  if (elemSourceType.isBF16() && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
  if (elemSourceType.isF16() && elemDestType.isF16())
    return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
  if (elemSourceType.isBF16() && elemDestType.isBF16())
    return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
  if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
    return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
  return std::nullopt;
}

namespace {
struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
  MFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type outType = typeConverter->convertType(op.getDestD().getType());
    Type intrinsicOutType = outType;
    if (auto outVecType = dyn_cast<VectorType>(outType))
      if (outVecType.getElementType().isBF16())
        intrinsicOutType = outVecType.clone(rewriter.getI16Type());

    if (chipset.majorVersion != 9 || chipset.minorVersion < 0x08)
      return op->emitOpError("MFMA only supported on gfx908+");
    uint32_t getBlgpField = static_cast<uint32_t>(op.getBlgp());
    if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
      if (chipset.minorVersion < 0x40)
        return op.emitOpError("negation unsupported on older than gfx940");
      getBlgpField |=
          op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
    }
    std::optional<StringRef> maybeIntrinsic = mfmaOpToIntrinsic(op, chipset);
    if (!maybeIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching MFMA size on given chipset");
    OperationState loweredOp(loc, *maybeIntrinsic);
    loweredOp.addTypes(intrinsicOutType);
    loweredOp.addOperands(
        {mfmaConcatIfNeeded(rewriter, loc, adaptor.getSourceA()),
         mfmaConcatIfNeeded(rewriter, loc, adaptor.getSourceB()),
         adaptor.getDestC(), createI32Constant(rewriter, loc, op.getCbsz()),
         createI32Constant(rewriter, loc, op.getAbid()),
         createI32Constant(rewriter, loc, getBlgpField)});
    Value lowered = rewriter.create(loweredOp)->getResult(0);
    if (outType != intrinsicOutType)
      lowered = rewriter.create<LLVM::BitcastOp>(loc, outType, lowered);
    rewriter.replaceOp(op, lowered);
    return success();
  }
};

struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
  WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type outType = typeConverter->convertType(op.getDestD().getType());

    if (chipset.majorVersion != 11)
      return op->emitOpError("WMMA only supported on gfx11");

    std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);

    if (!maybeIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching WMMA on the given chipset");

    OperationState loweredOp(loc, *maybeIntrinsic);
    loweredOp.addTypes(outType);

    SmallVector<Value, 4> operands;
    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(),
                         adaptor.getSourceA(), operands);
    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(),
                         adaptor.getSourceB(), operands);
    wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(),
                          op.getSubwordOffset(), op.getClamp(), operands);

    loweredOp.addOperands(operands);
    Operation *lowered = rewriter.create(loweredOp);
    rewriter.replaceOp(op, lowered->getResults());

    return success();
  }
};

namespace {
struct ExtPackedFp8OpLowering final
    : public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
  ExtPackedFp8OpLowering(LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::ExtPackedFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

struct PackedTrunc2xFp8OpLowering final
    : public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> {
  PackedTrunc2xFp8OpLowering(LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::PackedTrunc2xFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

struct PackedStochRoundFp8OpLowering final
    : public ConvertOpToLLVMPattern<PackedStochRoundFp8Op> {
  PackedStochRoundFp8OpLowering(LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::PackedStochRoundFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PackedStochRoundFp8Op op,
                  PackedStochRoundFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // end namespace

LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
    ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (chipset.majorVersion != 9 || chipset.minorVersion < 0x40)
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type v4i8 =
      getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type()));
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
  Type f32 = getTypeConverter()->convertType(op.getResult().getType());

  Value source = adaptor.getSource();
  auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
  Type sourceElemType = getElementTypeOrSelf(op.getSource());
  // Extend to a v4i8
  if (!sourceVecType || sourceVecType.getNumElements() < 4) {
    Value longVec = rewriter.create<LLVM::UndefOp>(loc, v4i8);
    if (!sourceVecType) {
      longVec = rewriter.create<LLVM::InsertElementOp>(
          loc, longVec, source, createI32Constant(rewriter, loc, 0));
    } else {
      for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
        Value idx = createI32Constant(rewriter, loc, i);
        Value elem = rewriter.create<LLVM::ExtractElementOp>(loc, source, idx);
        longVec =
            rewriter.create<LLVM::InsertElementOp>(loc, longVec, elem, idx);
      }
    }
    source = longVec;
  }
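  // Illustrative shape: a scalar fp8 source s becomes <s, undef, undef, undef>
  // : vector<4xi8> here, so the bitcast below always produces a full i32 from
  // which the op's `index` selects the byte to convert.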
  Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
  Value wordSel = createI32Constant(rewriter, loc, op.getIndex());
  if (sourceElemType.isFloat8E5M2FNUZ()) {
    rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
                                                    wordSel);
  } else if (sourceElemType.isFloat8E4M3FNUZ()) {
    rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
                                                    wordSel);
  }
  return success();
}

LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
    PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (chipset.majorVersion != 9 || chipset.minorVersion < 0x40)
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Type resultType = op.getResult().getType();
  Type resultElemType = getElementTypeOrSelf(resultType);

  Value sourceA = adaptor.getSourceA();
  Value sourceB = adaptor.getSourceB();
  if (!sourceB)
    sourceB = rewriter.create<LLVM::UndefOp>(loc, sourceA.getType());
  Value existing = adaptor.getExisting();
  if (existing)
    existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
  else
    existing = rewriter.create<LLVM::UndefOp>(loc, i32);
  Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex());

  Value result;
  if (resultElemType.isFloat8E5M2FNUZ())
    result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
                                                   existing, wordSel);
  else if (resultElemType.isFloat8E4M3FNUZ())
    result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
                                                   existing, wordSel);

  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
  return success();
}

LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
    PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (chipset.majorVersion != 9 || chipset.minorVersion < 0x40)
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Type resultType = op.getResult().getType();
  Type resultElemType = getElementTypeOrSelf(resultType);

  Value source = adaptor.getSource();
  Value stoch = adaptor.getStochiasticParam();
  Value existing = adaptor.getExisting();
  if (existing)
    existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
  else
    existing = rewriter.create<LLVM::UndefOp>(loc, i32);
  Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex());

  Value result;
  if (resultElemType.isFloat8E5M2FNUZ())
    result = rewriter.create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch,
                                                   existing, byteSel);
  else if (resultElemType.isFloat8E4M3FNUZ())
    result = rewriter.create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch,
                                                   existing, byteSel);

  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
  return success();
}

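// A minimal usage sketch (assuming the standard mlir-opt driver): this pass is
// normally run as
//   mlir-opt --convert-amdgpu-to-rocdl="chipset=gfx90a" input.mlir
// where the chipset string must parse as an amdgpu::Chipset (e.g. gfx908).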
struct ConvertAMDGPUToROCDLPass
    : public impl::ConvertAMDGPUToROCDLBase<ConvertAMDGPUToROCDLPass> {
  ConvertAMDGPUToROCDLPass() = default;

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    FailureOr<Chipset> maybeChipset = Chipset::parse(chipset);
    if (failed(maybeChipset)) {
      emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
      return signalPassFailure();
    }

    RewritePatternSet patterns(ctx);
    LLVMTypeConverter converter(ctx);
    populateAMDGPUToROCDLConversionPatterns(converter, patterns, *maybeChipset);
    LLVMConversionTarget target(getContext());
    target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
    target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
    target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
                                                   RewritePatternSet &patterns,
                                                   Chipset chipset) {
  converter.addConversion([](BFloat16Type t) -> Type {
    return IntegerType::get(t.getContext(), 16);
  });
  converter.addConversion([&converter](VectorType t) -> std::optional<Type> {
    if (!t.getElementType().isBF16())
      return std::nullopt;
    return converter.convertType(t.clone(IntegerType::get(t.getContext(), 16)));
  });

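  // Illustrative consequence of the conversions above: bf16 lowers to i16 and,
  // e.g., vector<4xbf16> lowers to vector<4xi16> before any pattern fires.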
  patterns
      .add<RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
           RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
           RawBufferOpLowering<RawBufferAtomicFaddOp,
                               ROCDL::RawPtrBufferAtomicFaddOp>,
           RawBufferOpLowering<RawBufferAtomicFmaxOp,
                               ROCDL::RawPtrBufferAtomicFmaxOp>,
           RawBufferOpLowering<RawBufferAtomicSmaxOp,
                               ROCDL::RawPtrBufferAtomicSmaxOp>,
           RawBufferOpLowering<RawBufferAtomicUminOp,
                               ROCDL::RawPtrBufferAtomicUminOp>,
           RawBufferOpLowering<RawBufferAtomicCmpswapOp,
                               ROCDL::RawPtrBufferAtomicCmpSwap>,
           LDSBarrierOpLowering, MFMAOpLowering, WMMAOpLowering,
           ExtPackedFp8OpLowering, PackedTrunc2xFp8OpLowering,
           PackedStochRoundFp8OpLowering>(converter, chipset);
}

std::unique_ptr<Pass> mlir::createConvertAMDGPUToROCDLPass() {
  return std::make_unique<ConvertAMDGPUToROCDLPass>();
}