MLIR  21.0.0git
AMDGPUDialect.cpp
Go to the documentation of this file.
1 //===- AMDGPUDialect.cpp - MLIR AMDGPU dialect implementation --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the AMDGPU dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/TypeSwitch.h"

#include <limits>
#include <optional>
31 
32 using namespace mlir;
33 using namespace mlir::amdgpu;
34 
35 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.cpp.inc"
36 
37 void AMDGPUDialect::initialize() {
38  addOperations<
39 #define GET_OP_LIST
40 #include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
41  >();
42  addAttributes<
43 #define GET_ATTRDEF_LIST
44 #include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
45  >();
46 }
47 
48 //===----------------------------------------------------------------------===//
49 // 8-bit float ops
50 //===----------------------------------------------------------------------===//
51 LogicalResult PackedTrunc2xFp8Op::verify() {
52  if (getExisting() && getExisting().getType() != getResult().getType())
53  return emitOpError("existing values must have same type as result");
54  return success();
55 }
56 
57 LogicalResult PackedStochRoundFp8Op::verify() {
58  if (getExisting() && getExisting().getType() != getResult().getType())
59  return emitOpError("existing values must have same type as result");
60  return success();
61 }
62 
63 //===----------------------------------------------------------------------===//
64 // FatRawBufferCastOp
65 //===----------------------------------------------------------------------===//
66 
67 /// Convert the type `source` to one with the same sizes and strides - and
68 /// offset, unless `stripOffset` is true, in which case the offset is reset to
69 /// 0, if the offset should be reset but the layout of `source` isn't either the
70 /// identity layout or a strided layout, this function fails.
71 static FailureOr<MemRefType> getFatRawBufferTypeLike(MemRefType source,
72  bool resetOffset) {
73  MLIRContext *ctx = source.getContext();
74  MemRefType::Builder mb(source);
75  mb.setMemorySpace(
76  amdgpu::AddressSpaceAttr::get(ctx, amdgpu::AddressSpace::FatRawBuffer));
77  MemRefLayoutAttrInterface layout = source.getLayout();
78  if (resetOffset && !layout.isIdentity()) {
79  auto stridedLayout = dyn_cast<StridedLayoutAttr>(layout);
80  if (!stridedLayout)
81  return failure();
82  mb.setLayout(StridedLayoutAttr::get(ctx, 0, stridedLayout.getStrides()));
83  }
84  return (MemRefType)(mb);
85 }
86 
87 LogicalResult FatRawBufferCastOp::inferReturnTypes(
88  MLIRContext *context, std::optional<Location> location, ValueRange operands,
89  DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
90  SmallVectorImpl<Type> &inferredReturnTypes) {
91  Adaptor adaptor(operands, attributes, properties, regions);
92  auto sourceType =
93  dyn_cast_if_present<MemRefType>(adaptor.getSource().getType());
94  if (!sourceType)
95  return failure();
96  FailureOr<MemRefType> resultType =
97  getFatRawBufferTypeLike(sourceType, adaptor.getResetOffset());
98  if (failed(resultType))
99  return failure();
100  inferredReturnTypes = SmallVector<Type>{*resultType};
101  return success();
102 }
103 
104 LogicalResult FatRawBufferCastOp::verify() {
105  FailureOr<MemRefType> expectedResultType =
106  getFatRawBufferTypeLike(getSource().getType(), getResetOffset());
107  if (failed(expectedResultType))
108  return emitOpError("source type ")
109  << getSource().getType() << " can't have its offset reset";
110  if (getResult().getType() != *expectedResultType)
111  return emitOpError("expected result type to be ")
112  << *expectedResultType << " but got " << getResult().getType();
113  return success();
114 }
115 
116 static bool hasGlobalMemorySpace(Attribute memorySpace) {
117  if (!memorySpace)
118  return true;
119  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
120  return intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
121  if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
122  return gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
123  return false;
124 }
125 
126 static bool hasWorkgroupMemorySpace(Attribute memorySpace) {
127  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
128  return intMemorySpace.getInt() == 3;
129  if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
130  return gpuMemorySpace.getValue() == gpu::AddressSpace::Workgroup;
131  return false;
132 }
133 
134 //===----------------------------------------------------------------------===//
135 // RawBuffer*Op
136 //===----------------------------------------------------------------------===//
137 template <typename T>
138 static LogicalResult verifyRawBufferOp(T &op) {
139  MemRefType bufferType = llvm::cast<MemRefType>(op.getMemref().getType());
140  bool isGlobal = hasGlobalMemorySpace(bufferType.getMemorySpace());
141 
142  if (!isGlobal)
143  return op.emitOpError(
144  "Buffer ops must operate on a memref in global memory");
145  if (!bufferType.hasRank())
146  return op.emitOpError(
147  "Cannot meaningfully buffer_store to an unranked memref");
148  if (static_cast<int64_t>(op.getIndices().size()) != bufferType.getRank())
149  return op.emitOpError("Expected " + Twine(bufferType.getRank()) +
150  " indices to memref");
151  return success();
152 }
153 
154 LogicalResult RawBufferLoadOp::verify() { return verifyRawBufferOp(*this); }
155 
156 LogicalResult RawBufferStoreOp::verify() { return verifyRawBufferOp(*this); }
157 
158 LogicalResult RawBufferAtomicFaddOp::verify() {
159  return verifyRawBufferOp(*this);
160 }
161 
162 LogicalResult RawBufferAtomicFmaxOp::verify() {
163  return verifyRawBufferOp(*this);
164 }
165 
166 LogicalResult RawBufferAtomicSmaxOp::verify() {
167  return verifyRawBufferOp(*this);
168 }
169 
170 LogicalResult RawBufferAtomicUminOp::verify() {
171  return verifyRawBufferOp(*this);
172 }
173 
174 LogicalResult RawBufferAtomicCmpswapOp::verify() {
175  return verifyRawBufferOp(*this);
176 }
177 
178 static std::optional<uint32_t> getConstantUint32(Value v) {
179  APInt cst;
180  if (!v.getType().isInteger(32))
181  return std::nullopt;
182  if (matchPattern(v, m_ConstantInt(&cst)))
183  return cst.getZExtValue();
184  return std::nullopt;
185 }
186 
187 template <typename OpType>
188 static bool staticallyOutOfBounds(OpType op) {
189  if (!op.getBoundsCheck())
190  return false;
191  MemRefType bufferType = op.getMemref().getType();
192  if (!bufferType.hasStaticShape())
193  return false;
194  int64_t offset;
195  SmallVector<int64_t> strides;
196  if (failed(bufferType.getStridesAndOffset(strides, offset)))
197  return false;
198  int64_t result = offset + op.getIndexOffset().value_or(0);
199  if (op.getSgprOffset()) {
200  std::optional<uint32_t> sgprOffset = getConstantUint32(op.getSgprOffset());
201  if (!sgprOffset)
202  return false;
203  result += *sgprOffset;
204  }
205  if (strides.size() != op.getIndices().size())
206  return false;
207  int64_t indexVal = 0;
208  for (auto pair : llvm::zip(strides, op.getIndices())) {
209  int64_t stride = std::get<0>(pair);
210  Value idx = std::get<1>(pair);
211  std::optional<uint32_t> idxVal = getConstantUint32(idx);
212  if (!idxVal)
213  return false;
214  indexVal += stride * *idxVal;
215  }
216  result += indexVal;
217  if (result > std::numeric_limits<uint32_t>::max())
218  // Overflow means don't drop
219  return false;
220  return result >= bufferType.getNumElements();
221 }
222 
223 namespace {
224 template <typename OpType>
225 struct RemoveStaticallyOobBufferLoads final : public OpRewritePattern<OpType> {
227 
228  LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
229  if (!staticallyOutOfBounds(op))
230  return failure();
231  Type loadType = op.getResult().getType();
232  rw.replaceOpWithNewOp<arith::ConstantOp>(op, loadType,
233  rw.getZeroAttr(loadType));
234  return success();
235  }
236 };
237 
238 template <typename OpType>
239 struct RemoveStaticallyOobBufferWrites final : public OpRewritePattern<OpType> {
241 
242  LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
243  if (!staticallyOutOfBounds(op))
244  return failure();
245 
246  rw.eraseOp(op);
247  return success();
248  }
249 };
250 } // end namespace
251 
252 void RawBufferLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
253  MLIRContext *context) {
254  results.add<RemoveStaticallyOobBufferLoads<RawBufferLoadOp>>(context);
255 }
256 
257 void RawBufferStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
258  MLIRContext *context) {
259  results.add<RemoveStaticallyOobBufferWrites<RawBufferStoreOp>>(context);
260 }
261 
262 void RawBufferAtomicFaddOp::getCanonicalizationPatterns(
263  RewritePatternSet &results, MLIRContext *context) {
264  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFaddOp>>(context);
265 }
266 
267 void RawBufferAtomicFmaxOp::getCanonicalizationPatterns(
268  RewritePatternSet &results, MLIRContext *context) {
269  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFmaxOp>>(context);
270 }
271 
272 void RawBufferAtomicSmaxOp::getCanonicalizationPatterns(
273  RewritePatternSet &results, MLIRContext *context) {
274  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicSmaxOp>>(context);
275 }
276 
277 void RawBufferAtomicUminOp::getCanonicalizationPatterns(
278  RewritePatternSet &results, MLIRContext *context) {
279  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicUminOp>>(context);
280 }
281 
282 void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
283  RewritePatternSet &results, MLIRContext *context) {
284  results.add<RemoveStaticallyOobBufferLoads<RawBufferAtomicCmpswapOp>>(
285  context);
286 }
287 
288 //===----------------------------------------------------------------------===//
289 // WMMAOp
290 //===----------------------------------------------------------------------===//
291 LogicalResult WMMAOp::verify() {
292  Type sourceAType = getSourceA().getType();
293  Type sourceBType = getSourceB().getType();
294  Type destType = getDestC().getType();
295 
296  VectorType sourceVectorAType = dyn_cast<VectorType>(sourceAType);
297  VectorType sourceVectorBType = dyn_cast<VectorType>(sourceBType);
298  VectorType destVectorType = dyn_cast<VectorType>(destType);
299 
300  Type sourceAElemType = sourceVectorAType.getElementType();
301  Type sourceBElemType = sourceVectorBType.getElementType();
302  Type destElemType = destVectorType.getElementType();
303 
304  if (sourceVectorAType.getNumElements() !=
305  sourceVectorBType.getNumElements()) {
306  return emitOpError("source vectors have different lengths: ")
307  << sourceVectorAType << " vs. " << sourceVectorBType;
308  }
309 
310  bool isDestFloat = isa<Float32Type, Float16Type, BFloat16Type>(destElemType);
311  bool isSrcFloat =
312  isa<Float16Type, BFloat16Type, Float8E4M3FNType, Float8E5M2Type>(
313  sourceAElemType);
314 
315  if (isDestFloat && !isSrcFloat) {
316  return emitOpError("Expected float sources with float destination");
317  }
318 
319  if (!isDestFloat && isSrcFloat) {
320  return emitOpError("Expected int sources with int destination");
321  }
322 
323  if (sourceAElemType != sourceBElemType &&
324  !(isa<Float8E5M2Type, Float8E4M3FNType>(sourceAElemType) &&
325  isa<Float8E5M2Type, Float8E4M3FNType>(sourceBElemType))) {
326  return emitOpError(
327  "source element types much match (except for fp8) but have ")
328  << sourceAType << " and " << sourceBType;
329  }
330  return success();
331 }
332 
333 //===----------------------------------------------------------------------===//
334 // MFMAOp
335 //===----------------------------------------------------------------------===//
336 LogicalResult MFMAOp::verify() {
337  constexpr uint32_t waveSize = 64;
338  Builder b(getContext());
339 
340  Type sourceType = getSourceA().getType();
341  Type destType = getDestC().getType();
342 
343  Type sourceElem = sourceType, destElem = destType;
344  uint32_t sourceLen = 1, destLen = 1;
345  if (auto sourceVector = llvm::dyn_cast<VectorType>(sourceType)) {
346  sourceLen = sourceVector.getNumElements();
347  sourceElem = sourceVector.getElementType();
348  }
349  if (auto destVector = llvm::dyn_cast<VectorType>(destType)) {
350  destLen = destVector.getNumElements();
351  destElem = destVector.getElementType();
352  }
353 
354  Type sourceBType = getSourceB().getType();
355  if (sourceElem.isFloat(8) || sourceElem.isFloat(6) || sourceElem.isFloat(4)) {
356  int64_t sourceBLen = 1;
357  Type sourceBElem = sourceBType;
358  if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
359  sourceBLen = sourceBVector.getNumElements();
360  sourceBElem = sourceBVector.getElementType();
361  }
362  if (!sourceBElem.isFloat(8) && !sourceBElem.isFloat(6) &&
363  !sourceBElem.isFloat(4))
364  return emitOpError("expected both source operands to have small-float "
365  "elements if one does");
366  if (sourceLen != sourceBLen)
367  return emitOpError(
368  "expected both small-float source vectors to have the same length");
369  } else {
370  if (sourceType != sourceBType)
371  return emitOpError("expected both non-small-float source operand types "
372  "to match exactly");
373  }
374  // Normalize the wider integer types the compiler expects to i8
375  if (sourceElem.isInteger(32)) {
376  sourceLen *= 4;
377  sourceElem = b.getI8Type();
378  }
379  if (sourceElem.isInteger(64)) {
380  sourceLen *= 8;
381  sourceElem = b.getI8Type();
382  }
383 
384  int64_t numSourceElems = (getM() * getK() * getBlocks()) / waveSize;
385  if (sourceLen != numSourceElems)
386  return emitOpError("expected " + Twine(numSourceElems) +
387  " source values for this operation but got " +
388  Twine(sourceLen));
389 
390  int64_t numDestElems = (getM() * getN() * getBlocks()) / waveSize;
391  if (destLen != numDestElems)
392  return emitOpError("expected " + Twine(numDestElems) +
393  " result values for this operation but got " +
394  Twine(destLen));
395 
396  if (destElem.isF64() && getBlgp() != MFMAPermB::none)
397  return emitOpError(
398  "double-precision ops do not support permuting lanes of B");
399  if (destElem.isF64() && getCbsz() != 0)
400  return emitOpError(
401  "double-precision ops do not support permuting lanes of A");
402  if (getAbid() >= (1u << getCbsz()))
403  return emitOpError(
404  "block ID for permuting A (abid) must be below 2 ** cbsz");
405 
406  if ((getNegateA() || getNegateB() || getNegateC()) && !destElem.isF64())
407  return emitOpError(
408  "negation flags only available for double-precision operations");
409 
410  return success();
411 }
412 
413 //===----------------------------------------------------------------------===//
414 // DPPOp
415 //===----------------------------------------------------------------------===//
416 LogicalResult DPPOp::verify() {
417  Type srcType = getSrc().getType();
418  if (srcType.getIntOrFloatBitWidth() > 64) {
419  return emitOpError("integer and floating point types larger than 64 bits "
420  "are not supported");
421  }
422 
423  DPPPerm kind = getKind();
424  Attribute permArgument = getPermArgument().value_or(Attribute{});
425 
426  switch (kind) {
427 
428  case DPPPerm::quad_perm: {
429  auto quadPermAttr = dyn_cast_or_null<ArrayAttr>(permArgument);
430  if (!quadPermAttr || quadPermAttr.size() != 4) {
431  return emitOpError("quad_perm attribute must have exactly 4 elements");
432  }
433  for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
434  int32_t num = elem.getInt();
435  if (num < 0 || num > 3) {
436  return emitOpError(
437  "Each element of quad_perm must be in the range [0, 3]");
438  }
439  }
440  } break;
441 
442  case DPPPerm::row_shl:
443  case DPPPerm::row_shr:
444  case DPPPerm::row_ror: {
445  if (!permArgument) {
446  return emitOpError("Attribute '" + Twine(stringifyDPPPerm(kind)) +
447  "' value not specified");
448  }
449  if (auto intAttr = dyn_cast<IntegerAttr>(permArgument)) {
450  uint32_t attrValue = intAttr.getInt();
451  if (attrValue < 1 || attrValue > 15) {
452  return emitOpError("Attribute value must be between 1 and 15");
453  }
454  }
455  } break;
456 
457  case DPPPerm::wave_shl:
458  case DPPPerm::wave_shr:
459  case DPPPerm::wave_rol:
460  case DPPPerm::wave_ror:
461  case DPPPerm::row_mirror:
462  case DPPPerm::row_half_mirror:
463  case DPPPerm::row_bcast_15:
464  case DPPPerm::row_bcast_31: {
465  if (permArgument && !isa<UnitAttr>(permArgument)) {
466  return emitOpError("Expected unit attribute for permArgument, but found "
467  "non-trivial argument");
468  }
469  break;
470  }
471  }
472  return success();
473 }
474 
475 LogicalResult GatherToLDSOp::verify() {
476  MemRefType srcType = cast<MemRefType>(getSrc().getType());
477  MemRefType dstType = cast<MemRefType>(getDst().getType());
478 
480  return emitOpError(
481  "destination types must have static shape and contiguous");
482 
483  auto elemType = srcType.getElementType();
484  // Check $src and $dst element types are the same.
485  if (elemType != dstType.getElementType())
486  return emitOpError("source and destination element types must match");
487 
488  // copy type sizes should be 1, 2, or 4 bytes.
489  auto transferType = getTransferType();
490  size_t transferSize;
491  if (auto vectorTransfer = dyn_cast<VectorType>(transferType)) {
492  transferSize = vectorTransfer.getNumElements() *
493  vectorTransfer.getElementTypeBitWidth();
494  } else {
495  transferSize = transferType.getIntOrFloatBitWidth();
496  }
497  if (transferSize != 8 && transferSize != 16 && transferSize != 32)
498  return emitOpError("Transfering type size must be 8, 16, or 32 bits");
499 
500  if (!hasGlobalMemorySpace(srcType.getMemorySpace()))
501  return emitOpError("source memory address space must be Global");
502 
503  if (!hasWorkgroupMemorySpace(dstType.getMemorySpace()))
504  return emitOpError("destination memory address space must be Workgroup");
505 
506  return success();
507 }
508 
509 #include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
510 
511 #define GET_ATTRDEF_CLASSES
512 #include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
513 
514 #define GET_OP_CLASSES
515 #include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
static LogicalResult verifyRawBufferOp(T &op)
static FailureOr< MemRefType > getFatRawBufferTypeLike(MemRefType source, bool resetOffset)
Convert the type source to one with the same sizes and strides - and offset, unless stripOffset is tr...
static bool hasGlobalMemorySpace(Attribute memorySpace)
static bool hasWorkgroupMemorySpace(Attribute memorySpace)
static std::optional< uint32_t > getConstantUint32(Value v)
static bool staticallyOutOfBounds(OpType op)
static MLIRContext * getContext(OpFoldResult val)
union mlir::linalg::@1195::ArityGroupAndKind::Kind kind
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:320
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:166
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
Definition: BuiltinTypes.h:187
Builder & setMemorySpace(Attribute newMemorySpace)
Definition: BuiltinTypes.h:192
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:753
This class provides an abstraction over the different types of ranges over Regions.
Definition: Region.h:346
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:815
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:504
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isFloat() const
Return true if this is an float type (with the specified width).
Definition: Types.cpp:45
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:56
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
bool isStaticShapeAndContiguousRowMajor(MemRefType type)
Returns true, if the memref type has static shapes and represents a contiguous chunk of memory.
Definition: MemRefUtils.cpp:24
uint64_t getN(LevelType lt)
Definition: Enums.h:442
uint64_t getM(LevelType lt)
Definition: Enums.h:443
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:527
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:424
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:318