MLIR  19.0.0git
AMDGPUDialect.cpp
Go to the documentation of this file.
1 //===- AMDGPUDialect.cpp - MLIR AMDGPU dialect implementation --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the AMDGPU dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/Diagnostics.h"
21 #include "mlir/IR/Matchers.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/IR/TypeUtilities.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 
27 #include <limits>
28 #include <optional>
29 
30 using namespace mlir;
31 using namespace mlir::amdgpu;
32 
33 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.cpp.inc"
34 
// Dialect initialization hook: registers the dialect's operations and
// attributes. The concrete lists are not spelled out here; they are pulled in
// by re-including the .cpp.inc files (presumably TableGen-generated) with
// GET_OP_LIST / GET_ATTRDEF_LIST defined, which expands to the comma-separated
// class lists used as template arguments below.
35 void AMDGPUDialect::initialize() {
36  addOperations<
37 #define GET_OP_LIST
38 #include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
39  >();
40  addAttributes<
41 #define GET_ATTRDEF_LIST
42 #include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
43  >();
44 }
45 
46 //===----------------------------------------------------------------------===//
47 // 8-bit float ops
48 //===----------------------------------------------------------------------===//
50  if (getExisting() && getExisting().getType() != getResult().getType())
51  return emitOpError("existing values must have same type as result");
52  return success();
53 }
54 
56  if (getExisting() && getExisting().getType() != getResult().getType())
57  return emitOpError("existing values must have same type as result");
58  return success();
59 }
60 
61 //===----------------------------------------------------------------------===//
62 // RawBuffer*Op
63 //===----------------------------------------------------------------------===//
64 template <typename T>
66  MemRefType bufferType = llvm::cast<MemRefType>(op.getMemref().getType());
67  Attribute memorySpace = bufferType.getMemorySpace();
68  bool isGlobal = false;
69  if (!memorySpace)
70  isGlobal = true;
71  else if (auto intMemorySpace = llvm::dyn_cast<IntegerAttr>(memorySpace))
72  isGlobal = intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
73  else if (auto gpuMemorySpace =
74  llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
75  isGlobal = gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
76 
77  if (!isGlobal)
78  return op.emitOpError(
79  "Buffer ops must operate on a memref in global memory");
80  if (!bufferType.hasRank())
81  return op.emitOpError(
82  "Cannot meaningfully buffer_store to an unranked memref");
83  if (static_cast<int64_t>(op.getIndices().size()) != bufferType.getRank())
84  return op.emitOpError("Expected " + Twine(bufferType.getRank()) +
85  " indices to memref");
86  return success();
87 }
88 
90 
92 
94  return verifyRawBufferOp(*this);
95 }
96 
98  return verifyRawBufferOp(*this);
99 }
100 
102  return verifyRawBufferOp(*this);
103 }
104 
106  return verifyRawBufferOp(*this);
107 }
108 
110  return verifyRawBufferOp(*this);
111 }
112 
113 static std::optional<uint32_t> getConstantUint32(Value v) {
114  APInt cst;
115  if (!v.getType().isInteger(32))
116  return std::nullopt;
117  if (matchPattern(v, m_ConstantInt(&cst)))
118  return cst.getZExtValue();
119  return std::nullopt;
120 }
121 
122 template <typename OpType>
123 static bool staticallyOutOfBounds(OpType op) {
124  if (!op.getBoundsCheck())
125  return false;
126  MemRefType bufferType = op.getMemref().getType();
127  if (!bufferType.hasStaticShape())
128  return false;
129  int64_t offset;
130  SmallVector<int64_t> strides;
131  if (failed(getStridesAndOffset(bufferType, strides, offset)))
132  return false;
133  int64_t result = offset + op.getIndexOffset().value_or(0);
134  if (op.getSgprOffset()) {
135  std::optional<uint32_t> sgprOffset = getConstantUint32(op.getSgprOffset());
136  if (!sgprOffset)
137  return false;
138  result += *sgprOffset;
139  }
140  if (strides.size() != op.getIndices().size())
141  return false;
142  int64_t indexVal = 0;
143  for (auto pair : llvm::zip(strides, op.getIndices())) {
144  int64_t stride = std::get<0>(pair);
145  Value idx = std::get<1>(pair);
146  std::optional<uint32_t> idxVal = getConstantUint32(idx);
147  if (!idxVal)
148  return false;
149  indexVal += stride * *idxVal;
150  }
151  result += indexVal;
152  if (result > std::numeric_limits<uint32_t>::max())
153  // Overflow means don't drop
154  return false;
155  return result >= bufferType.getNumElements();
156 }
157 
158 namespace {
159 template <typename OpType>
160 struct RemoveStaticallyOobBufferLoads final : public OpRewritePattern<OpType> {
162 
163  LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
164  if (!staticallyOutOfBounds(op))
165  return failure();
166  Type loadType = op.getResult().getType();
167  rw.replaceOpWithNewOp<arith::ConstantOp>(op, loadType,
168  rw.getZeroAttr(loadType));
169  return success();
170  }
171 };
172 
173 template <typename OpType>
174 struct RemoveStaticallyOobBufferWrites final : public OpRewritePattern<OpType> {
176 
177  LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
178  if (!staticallyOutOfBounds(op))
179  return failure();
180 
181  rw.eraseOp(op);
182  return success();
183  }
184 };
185 } // end namespace
186 
187 void RawBufferLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
188  MLIRContext *context) {
189  results.add<RemoveStaticallyOobBufferLoads<RawBufferLoadOp>>(context);
190 }
191 
192 void RawBufferStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
193  MLIRContext *context) {
194  results.add<RemoveStaticallyOobBufferWrites<RawBufferStoreOp>>(context);
195 }
196 
197 void RawBufferAtomicFaddOp::getCanonicalizationPatterns(
198  RewritePatternSet &results, MLIRContext *context) {
199  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFaddOp>>(context);
200 }
201 
202 void RawBufferAtomicFmaxOp::getCanonicalizationPatterns(
203  RewritePatternSet &results, MLIRContext *context) {
204  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFmaxOp>>(context);
205 }
206 
207 void RawBufferAtomicSmaxOp::getCanonicalizationPatterns(
208  RewritePatternSet &results, MLIRContext *context) {
209  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicSmaxOp>>(context);
210 }
211 
212 void RawBufferAtomicUminOp::getCanonicalizationPatterns(
213  RewritePatternSet &results, MLIRContext *context) {
214  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicUminOp>>(context);
215 }
216 
217 void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
218  RewritePatternSet &results, MLIRContext *context) {
219  results.add<RemoveStaticallyOobBufferLoads<RawBufferAtomicCmpswapOp>>(
220  context);
221 }
222 
223 //===----------------------------------------------------------------------===//
224 // WMMAOp
225 //===----------------------------------------------------------------------===//
227  Type sourceAType = getSourceA().getType();
228  Type destType = getDestC().getType();
229 
230  VectorType sourceVectorAType = sourceAType.dyn_cast<VectorType>();
231  VectorType destVectorType = destType.dyn_cast<VectorType>();
232 
233  Type sourceAElemType = sourceVectorAType.getElementType();
234  Type destElemType = destVectorType.getElementType();
235 
236  bool isDestFloat =
237  (destElemType.isF32() || destElemType.isF16() || destElemType.isBF16());
238  bool isSrcFloat = (sourceAElemType.isF16() || sourceAElemType.isBF16());
239 
240  if (isDestFloat && !isSrcFloat) {
241  return emitOpError("Expected float sources with float destination");
242  }
243 
244  if (!isDestFloat && isSrcFloat) {
245  return emitOpError("Expected int sources with int destination");
246  }
247 
248  return success();
249 }
250 
251 //===----------------------------------------------------------------------===//
252 // MFMAOp
253 //===----------------------------------------------------------------------===//
255  constexpr uint32_t waveSize = 64;
256  Builder b(getContext());
257 
258  Type sourceType = getSourceA().getType();
259  Type destType = getDestC().getType();
260 
261  Type sourceElem = sourceType, destElem = destType;
262  uint32_t sourceLen = 1, destLen = 1;
263  if (auto sourceVector = llvm::dyn_cast<VectorType>(sourceType)) {
264  sourceLen = sourceVector.getNumElements();
265  sourceElem = sourceVector.getElementType();
266  }
267  if (auto destVector = llvm::dyn_cast<VectorType>(destType)) {
268  destLen = destVector.getNumElements();
269  destElem = destVector.getElementType();
270  }
271 
272  Type sourceBType = getSourceB().getType();
273  if (sourceElem.isFloat8E5M2FNUZ() || sourceElem.isFloat8E4M3FNUZ()) {
274  int64_t sourceBLen = 1;
275  Type sourceBElem = sourceBType;
276  if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
277  sourceBLen = sourceBVector.getNumElements();
278  sourceBElem = sourceBVector.getElementType();
279  }
280  if (!sourceBElem.isFloat8E5M2FNUZ() && !sourceBElem.isFloat8E4M3FNUZ())
281  return emitOpError("expected both source operands to have f8 elements");
282  if (sourceLen != sourceBLen)
283  return emitOpError(
284  "expected both f8 source vectors to have the same length");
285  } else {
286  if (sourceType != sourceBType)
287  return emitOpError(
288  "expected both non-f8 source operand types to match exactly");
289  }
290  // Normalize the wider integer types the compiler expects to i8
291  if (sourceElem.isInteger(32)) {
292  sourceLen *= 4;
293  sourceElem = b.getI8Type();
294  }
295  if (sourceElem.isInteger(64)) {
296  sourceLen *= 8;
297  sourceElem = b.getI8Type();
298  }
299 
300  int64_t numSourceElems = (getM() * getK() * getBlocks()) / waveSize;
301  if (sourceLen != numSourceElems)
302  return emitOpError("expected " + Twine(numSourceElems) +
303  " source values for this operation but got " +
304  Twine(sourceLen));
305 
306  int64_t numDestElems = (getM() * getN() * getBlocks()) / waveSize;
307  if (destLen != numDestElems)
308  return emitOpError("expected " + Twine(numDestElems) +
309  " result values for this operation but got " +
310  Twine(destLen));
311 
312  if (destElem.isF64() && getBlgp() != MFMAPermB::none)
313  return emitOpError(
314  "double-precision ops do not support permuting lanes of B");
315  if (destElem.isF64() && getCbsz() != 0)
316  return emitOpError(
317  "double-precision ops do not support permuting lanes of A");
318  if (getAbid() >= (1u << getCbsz()))
319  return emitOpError(
320  "block ID for permuting A (abid) must be below 2 ** cbsz");
321 
322  if ((getNegateA() || getNegateB() || getNegateC()) && !destElem.isF64())
323  return emitOpError(
324  "negation flags only available for double-precision operations");
325 
326  return success();
327 }
328 
329 #include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
330 
331 #define GET_ATTRDEF_CLASSES
332 #include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
333 
334 #define GET_OP_CLASSES
335 #include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
static LogicalResult verifyRawBufferOp(T &op)
static std::optional< uint32_t > getConstantUint32(Value v)
static bool staticallyOutOfBounds(OpType op)
static MLIRContext * getContext(OpFoldResult val)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:331
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
U dyn_cast() const
Definition: Types.h:330
bool isF32() const
Definition: Types.cpp:51
bool isFloat8E4M3FNUZ() const
Definition: Types.cpp:42
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:58
bool isF16() const
Definition: Types.cpp:49
bool isBF16() const
Definition: Types.cpp:48
bool isFloat8E5M2FNUZ() const
Definition: Types.cpp:39
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
uint64_t getN(LevelType lt)
Definition: Enums.h:438
uint64_t getM(LevelType lt)
Definition: Enums.h:439
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:438
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358