MLIR  20.0.0git
AMDGPUDialect.cpp
Go to the documentation of this file.
1 //===- AMDGPUDialect.cpp - MLIR AMDGPU dialect implementation --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the AMDGPU dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/IR/Diagnostics.h"
22 #include "mlir/IR/Matchers.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/IR/TypeUtilities.h"
26 #include "llvm/ADT/TypeSwitch.h"
27 
28 #include <limits>
29 #include <optional>
30 
31 using namespace mlir;
32 using namespace mlir::amdgpu;
33 
34 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.cpp.inc"
35 
36 void AMDGPUDialect::initialize() {
37  addOperations<
38 #define GET_OP_LIST
39 #include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
40  >();
41  addAttributes<
42 #define GET_ATTRDEF_LIST
43 #include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
44  >();
45 }
46 
47 //===----------------------------------------------------------------------===//
48 // 8-bit float ops
49 //===----------------------------------------------------------------------===//
50 LogicalResult PackedTrunc2xFp8Op::verify() {
51  if (getExisting() && getExisting().getType() != getResult().getType())
52  return emitOpError("existing values must have same type as result");
53  return success();
54 }
55 
56 LogicalResult PackedStochRoundFp8Op::verify() {
57  if (getExisting() && getExisting().getType() != getResult().getType())
58  return emitOpError("existing values must have same type as result");
59  return success();
60 }
61 
62 //===----------------------------------------------------------------------===//
63 // RawBuffer*Op
64 //===----------------------------------------------------------------------===//
65 template <typename T>
66 static LogicalResult verifyRawBufferOp(T &op) {
67  MemRefType bufferType = llvm::cast<MemRefType>(op.getMemref().getType());
68  Attribute memorySpace = bufferType.getMemorySpace();
69  bool isGlobal = false;
70  if (!memorySpace)
71  isGlobal = true;
72  else if (auto intMemorySpace = llvm::dyn_cast<IntegerAttr>(memorySpace))
73  isGlobal = intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
74  else if (auto gpuMemorySpace =
75  llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
76  isGlobal = gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
77 
78  if (!isGlobal)
79  return op.emitOpError(
80  "Buffer ops must operate on a memref in global memory");
81  if (!bufferType.hasRank())
82  return op.emitOpError(
83  "Cannot meaningfully buffer_store to an unranked memref");
84  if (static_cast<int64_t>(op.getIndices().size()) != bufferType.getRank())
85  return op.emitOpError("Expected " + Twine(bufferType.getRank()) +
86  " indices to memref");
87  return success();
88 }
89 
90 LogicalResult RawBufferLoadOp::verify() { return verifyRawBufferOp(*this); }
91 
92 LogicalResult RawBufferStoreOp::verify() { return verifyRawBufferOp(*this); }
93 
94 LogicalResult RawBufferAtomicFaddOp::verify() {
95  return verifyRawBufferOp(*this);
96 }
97 
98 LogicalResult RawBufferAtomicFmaxOp::verify() {
99  return verifyRawBufferOp(*this);
100 }
101 
102 LogicalResult RawBufferAtomicSmaxOp::verify() {
103  return verifyRawBufferOp(*this);
104 }
105 
106 LogicalResult RawBufferAtomicUminOp::verify() {
107  return verifyRawBufferOp(*this);
108 }
109 
110 LogicalResult RawBufferAtomicCmpswapOp::verify() {
111  return verifyRawBufferOp(*this);
112 }
113 
114 static std::optional<uint32_t> getConstantUint32(Value v) {
115  APInt cst;
116  if (!v.getType().isInteger(32))
117  return std::nullopt;
118  if (matchPattern(v, m_ConstantInt(&cst)))
119  return cst.getZExtValue();
120  return std::nullopt;
121 }
122 
123 template <typename OpType>
124 static bool staticallyOutOfBounds(OpType op) {
125  if (!op.getBoundsCheck())
126  return false;
127  MemRefType bufferType = op.getMemref().getType();
128  if (!bufferType.hasStaticShape())
129  return false;
130  int64_t offset;
131  SmallVector<int64_t> strides;
132  if (failed(getStridesAndOffset(bufferType, strides, offset)))
133  return false;
134  int64_t result = offset + op.getIndexOffset().value_or(0);
135  if (op.getSgprOffset()) {
136  std::optional<uint32_t> sgprOffset = getConstantUint32(op.getSgprOffset());
137  if (!sgprOffset)
138  return false;
139  result += *sgprOffset;
140  }
141  if (strides.size() != op.getIndices().size())
142  return false;
143  int64_t indexVal = 0;
144  for (auto pair : llvm::zip(strides, op.getIndices())) {
145  int64_t stride = std::get<0>(pair);
146  Value idx = std::get<1>(pair);
147  std::optional<uint32_t> idxVal = getConstantUint32(idx);
148  if (!idxVal)
149  return false;
150  indexVal += stride * *idxVal;
151  }
152  result += indexVal;
153  if (result > std::numeric_limits<uint32_t>::max())
154  // Overflow means don't drop
155  return false;
156  return result >= bufferType.getNumElements();
157 }
158 
159 namespace {
160 template <typename OpType>
161 struct RemoveStaticallyOobBufferLoads final : public OpRewritePattern<OpType> {
163 
164  LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
165  if (!staticallyOutOfBounds(op))
166  return failure();
167  Type loadType = op.getResult().getType();
168  rw.replaceOpWithNewOp<arith::ConstantOp>(op, loadType,
169  rw.getZeroAttr(loadType));
170  return success();
171  }
172 };
173 
174 template <typename OpType>
175 struct RemoveStaticallyOobBufferWrites final : public OpRewritePattern<OpType> {
177 
178  LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
179  if (!staticallyOutOfBounds(op))
180  return failure();
181 
182  rw.eraseOp(op);
183  return success();
184  }
185 };
186 } // end namespace
187 
188 void RawBufferLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
189  MLIRContext *context) {
190  results.add<RemoveStaticallyOobBufferLoads<RawBufferLoadOp>>(context);
191 }
192 
193 void RawBufferStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
194  MLIRContext *context) {
195  results.add<RemoveStaticallyOobBufferWrites<RawBufferStoreOp>>(context);
196 }
197 
198 void RawBufferAtomicFaddOp::getCanonicalizationPatterns(
199  RewritePatternSet &results, MLIRContext *context) {
200  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFaddOp>>(context);
201 }
202 
203 void RawBufferAtomicFmaxOp::getCanonicalizationPatterns(
204  RewritePatternSet &results, MLIRContext *context) {
205  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFmaxOp>>(context);
206 }
207 
208 void RawBufferAtomicSmaxOp::getCanonicalizationPatterns(
209  RewritePatternSet &results, MLIRContext *context) {
210  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicSmaxOp>>(context);
211 }
212 
213 void RawBufferAtomicUminOp::getCanonicalizationPatterns(
214  RewritePatternSet &results, MLIRContext *context) {
215  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicUminOp>>(context);
216 }
217 
218 void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
219  RewritePatternSet &results, MLIRContext *context) {
220  results.add<RemoveStaticallyOobBufferLoads<RawBufferAtomicCmpswapOp>>(
221  context);
222 }
223 
224 //===----------------------------------------------------------------------===//
225 // WMMAOp
226 //===----------------------------------------------------------------------===//
227 LogicalResult WMMAOp::verify() {
228  Type sourceAType = getSourceA().getType();
229  Type destType = getDestC().getType();
230 
231  VectorType sourceVectorAType = dyn_cast<VectorType>(sourceAType);
232  VectorType destVectorType = dyn_cast<VectorType>(destType);
233 
234  Type sourceAElemType = sourceVectorAType.getElementType();
235  Type destElemType = destVectorType.getElementType();
236 
237  bool isDestFloat = isa<Float32Type, Float16Type, BFloat16Type>(destElemType);
238  bool isSrcFloat =
239  isa<Float16Type, BFloat16Type, Float8E4M3FNType, Float8E5M2Type>(
240  sourceAElemType);
241 
242  if (isDestFloat && !isSrcFloat) {
243  return emitOpError("Expected float sources with float destination");
244  }
245 
246  if (!isDestFloat && isSrcFloat) {
247  return emitOpError("Expected int sources with int destination");
248  }
249 
250  return success();
251 }
252 
253 //===----------------------------------------------------------------------===//
254 // MFMAOp
255 //===----------------------------------------------------------------------===//
256 LogicalResult MFMAOp::verify() {
257  constexpr uint32_t waveSize = 64;
258  Builder b(getContext());
259 
260  Type sourceType = getSourceA().getType();
261  Type destType = getDestC().getType();
262 
263  Type sourceElem = sourceType, destElem = destType;
264  uint32_t sourceLen = 1, destLen = 1;
265  if (auto sourceVector = llvm::dyn_cast<VectorType>(sourceType)) {
266  sourceLen = sourceVector.getNumElements();
267  sourceElem = sourceVector.getElementType();
268  }
269  if (auto destVector = llvm::dyn_cast<VectorType>(destType)) {
270  destLen = destVector.getNumElements();
271  destElem = destVector.getElementType();
272  }
273 
274  Type sourceBType = getSourceB().getType();
275  if (sourceElem.isFloat8E5M2FNUZ() || sourceElem.isFloat8E4M3FNUZ()) {
276  int64_t sourceBLen = 1;
277  Type sourceBElem = sourceBType;
278  if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
279  sourceBLen = sourceBVector.getNumElements();
280  sourceBElem = sourceBVector.getElementType();
281  }
282  if (!sourceBElem.isFloat8E5M2FNUZ() && !sourceBElem.isFloat8E4M3FNUZ())
283  return emitOpError("expected both source operands to have f8 elements");
284  if (sourceLen != sourceBLen)
285  return emitOpError(
286  "expected both f8 source vectors to have the same length");
287  } else {
288  if (sourceType != sourceBType)
289  return emitOpError(
290  "expected both non-f8 source operand types to match exactly");
291  }
292  // Normalize the wider integer types the compiler expects to i8
293  if (sourceElem.isInteger(32)) {
294  sourceLen *= 4;
295  sourceElem = b.getI8Type();
296  }
297  if (sourceElem.isInteger(64)) {
298  sourceLen *= 8;
299  sourceElem = b.getI8Type();
300  }
301 
302  int64_t numSourceElems = (getM() * getK() * getBlocks()) / waveSize;
303  if (sourceLen != numSourceElems)
304  return emitOpError("expected " + Twine(numSourceElems) +
305  " source values for this operation but got " +
306  Twine(sourceLen));
307 
308  int64_t numDestElems = (getM() * getN() * getBlocks()) / waveSize;
309  if (destLen != numDestElems)
310  return emitOpError("expected " + Twine(numDestElems) +
311  " result values for this operation but got " +
312  Twine(destLen));
313 
314  if (destElem.isF64() && getBlgp() != MFMAPermB::none)
315  return emitOpError(
316  "double-precision ops do not support permuting lanes of B");
317  if (destElem.isF64() && getCbsz() != 0)
318  return emitOpError(
319  "double-precision ops do not support permuting lanes of A");
320  if (getAbid() >= (1u << getCbsz()))
321  return emitOpError(
322  "block ID for permuting A (abid) must be below 2 ** cbsz");
323 
324  if ((getNegateA() || getNegateB() || getNegateC()) && !destElem.isF64())
325  return emitOpError(
326  "negation flags only available for double-precision operations");
327 
328  return success();
329 }
330 
331 //===----------------------------------------------------------------------===//
332 // DPPOp
333 //===----------------------------------------------------------------------===//
334 LogicalResult DPPOp::verify() {
335  Type srcType = getSrc().getType();
336  if (srcType.getIntOrFloatBitWidth() > 64) {
337  return emitOpError("integer and floating point types larger than 64 bits "
338  "are not supported");
339  }
340 
341  DPPPerm kind = getKind();
342  Attribute permArgument = getPermArgument().value_or(Attribute{});
343 
344  switch (kind) {
345 
346  case DPPPerm::quad_perm: {
347  auto quadPermAttr = dyn_cast_or_null<ArrayAttr>(permArgument);
348  if (!quadPermAttr || quadPermAttr.size() != 4) {
349  return emitOpError("quad_perm attribute must have exactly 4 elements");
350  }
351  for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
352  uint32_t num = elem.getInt();
353  if (num < 0 || num > 3) {
354  return emitOpError(
355  "Each element of quad_perm must be in the range [0, 3]");
356  }
357  }
358  } break;
359 
360  case DPPPerm::row_shl:
361  case DPPPerm::row_shr:
362  case DPPPerm::row_ror: {
363  if (!permArgument) {
364  return emitOpError("Attribute '" + Twine(stringifyDPPPerm(kind)) +
365  "' value not specified");
366  }
367  if (auto intAttr = dyn_cast<IntegerAttr>(permArgument)) {
368  uint32_t attrValue = intAttr.getInt();
369  if (attrValue < 1 || attrValue > 15) {
370  return emitOpError("Attribute value must be between 1 and 15");
371  }
372  }
373  } break;
374 
375  case DPPPerm::wave_shl:
376  case DPPPerm::wave_shr:
377  case DPPPerm::wave_rol:
378  case DPPPerm::wave_ror:
379  case DPPPerm::row_mirror:
380  case DPPPerm::row_half_mirror:
381  case DPPPerm::row_bcast_15:
382  case DPPPerm::row_bcast_31: {
383  if (permArgument && !isa<UnitAttr>(permArgument)) {
384  return emitOpError("Expected unit attribute for permArgument, but found "
385  "non-trivial argument");
386  }
387  break;
388  }
389  }
390  return success();
391 }
392 
393 #include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
394 
395 #define GET_ATTRDEF_CLASSES
396 #include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
397 
398 #define GET_OP_CLASSES
399 #include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
static LogicalResult verifyRawBufferOp(T &op)
static std::optional< uint32_t > getConstantUint32(Value v)
static bool staticallyOutOfBounds(OpType op)
static MLIRContext * getContext(OpFoldResult val)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:355
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isFloat8E4M3FNUZ() const
Definition: Types.cpp:46
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:66
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:133
bool isFloat8E5M2FNUZ() const
Definition: Types.cpp:43
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
uint64_t getN(LevelType lt)
Definition: Enums.h:442
uint64_t getM(LevelType lt)
Definition: Enums.h:443
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:485
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:522
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:426
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358