MLIR  21.0.0git
AMDGPUDialect.cpp
Go to the documentation of this file.
1 //===- AMDGPUDialect.cpp - MLIR AMDGPU dialect implementation --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the AMDGPU dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/IR/Diagnostics.h"
22 #include "mlir/IR/Matchers.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/IR/TypeUtilities.h"
26 #include "llvm/ADT/TypeSwitch.h"
27 
28 #include <limits>
29 #include <optional>
30 
31 using namespace mlir;
32 using namespace mlir::amdgpu;
33 
34 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.cpp.inc"
35 
// Registers all operations and attributes of the AMDGPU dialect. The actual
// op and attribute lists are generated by TableGen and spliced in through the
// GET_OP_LIST / GET_ATTRDEF_LIST macro includes.
void AMDGPUDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
      >();
}
46 
47 //===----------------------------------------------------------------------===//
48 // 8-bit float ops
49 //===----------------------------------------------------------------------===//
50 LogicalResult PackedTrunc2xFp8Op::verify() {
51  if (getExisting() && getExisting().getType() != getResult().getType())
52  return emitOpError("existing values must have same type as result");
53  return success();
54 }
55 
56 LogicalResult PackedStochRoundFp8Op::verify() {
57  if (getExisting() && getExisting().getType() != getResult().getType())
58  return emitOpError("existing values must have same type as result");
59  return success();
60 }
61 
62 //===----------------------------------------------------------------------===//
// FatRawBufferCastOp
64 //===----------------------------------------------------------------------===//
65 
66 /// Convert the type `source` to one with the same sizes and strides - and
67 /// offset, unless `stripOffset` is true, in which case the offset is reset to
68 /// 0, if the offset should be reset but the layout of `source` isn't either the
69 /// identity layout or a strided layout, this function fails.
70 static FailureOr<MemRefType> getFatRawBufferTypeLike(MemRefType source,
71  bool resetOffset) {
72  MLIRContext *ctx = source.getContext();
73  MemRefType::Builder mb(source);
74  mb.setMemorySpace(
75  amdgpu::AddressSpaceAttr::get(ctx, amdgpu::AddressSpace::FatRawBuffer));
76  MemRefLayoutAttrInterface layout = source.getLayout();
77  if (resetOffset && !layout.isIdentity()) {
78  auto stridedLayout = dyn_cast<StridedLayoutAttr>(layout);
79  if (!stridedLayout)
80  return failure();
81  mb.setLayout(StridedLayoutAttr::get(ctx, 0, stridedLayout.getStrides()));
82  }
83  return (MemRefType)(mb);
84 }
85 
86 LogicalResult FatRawBufferCastOp::inferReturnTypes(
87  MLIRContext *context, std::optional<Location> location, ValueRange operands,
88  DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
89  SmallVectorImpl<Type> &inferredReturnTypes) {
90  Adaptor adaptor(operands, attributes, properties, regions);
91  auto sourceType =
92  dyn_cast_if_present<MemRefType>(adaptor.getSource().getType());
93  if (!sourceType)
94  return failure();
95  FailureOr<MemRefType> resultType =
96  getFatRawBufferTypeLike(sourceType, adaptor.getResetOffset());
97  if (failed(resultType))
98  return failure();
99  inferredReturnTypes = SmallVector<Type>{*resultType};
100  return success();
101 }
102 
103 LogicalResult FatRawBufferCastOp::verify() {
104  FailureOr<MemRefType> expectedResultType =
105  getFatRawBufferTypeLike(getSource().getType(), getResetOffset());
106  if (failed(expectedResultType))
107  return emitOpError("source type ")
108  << getSource().getType() << " can't have its offset reset";
109  if (getResult().getType() != *expectedResultType)
110  return emitOpError("expected result type to be ")
111  << *expectedResultType << " but got " << getResult().getType();
112  return success();
113 }
114 
115 //===----------------------------------------------------------------------===//
116 // RawBuffer*Op
117 //===----------------------------------------------------------------------===//
118 template <typename T>
119 static LogicalResult verifyRawBufferOp(T &op) {
120  MemRefType bufferType = llvm::cast<MemRefType>(op.getMemref().getType());
121  Attribute memorySpace = bufferType.getMemorySpace();
122  bool isGlobal = false;
123  if (!memorySpace)
124  isGlobal = true;
125  else if (auto intMemorySpace = llvm::dyn_cast<IntegerAttr>(memorySpace))
126  isGlobal = intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
127  else if (auto gpuMemorySpace =
128  llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
129  isGlobal = gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
130 
131  if (!isGlobal)
132  return op.emitOpError(
133  "Buffer ops must operate on a memref in global memory");
134  if (!bufferType.hasRank())
135  return op.emitOpError(
136  "Cannot meaningfully buffer_store to an unranked memref");
137  if (static_cast<int64_t>(op.getIndices().size()) != bufferType.getRank())
138  return op.emitOpError("Expected " + Twine(bufferType.getRank()) +
139  " indices to memref");
140  return success();
141 }
142 
// The per-op verifiers below all delegate to the shared verifyRawBufferOp
// template defined above.
LogicalResult RawBufferLoadOp::verify() { return verifyRawBufferOp(*this); }

LogicalResult RawBufferStoreOp::verify() { return verifyRawBufferOp(*this); }

LogicalResult RawBufferAtomicFaddOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicFmaxOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicSmaxOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicUminOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicCmpswapOp::verify() {
  return verifyRawBufferOp(*this);
}
166 
167 static std::optional<uint32_t> getConstantUint32(Value v) {
168  APInt cst;
169  if (!v.getType().isInteger(32))
170  return std::nullopt;
171  if (matchPattern(v, m_ConstantInt(&cst)))
172  return cst.getZExtValue();
173  return std::nullopt;
174 }
175 
176 template <typename OpType>
177 static bool staticallyOutOfBounds(OpType op) {
178  if (!op.getBoundsCheck())
179  return false;
180  MemRefType bufferType = op.getMemref().getType();
181  if (!bufferType.hasStaticShape())
182  return false;
183  int64_t offset;
184  SmallVector<int64_t> strides;
185  if (failed(bufferType.getStridesAndOffset(strides, offset)))
186  return false;
187  int64_t result = offset + op.getIndexOffset().value_or(0);
188  if (op.getSgprOffset()) {
189  std::optional<uint32_t> sgprOffset = getConstantUint32(op.getSgprOffset());
190  if (!sgprOffset)
191  return false;
192  result += *sgprOffset;
193  }
194  if (strides.size() != op.getIndices().size())
195  return false;
196  int64_t indexVal = 0;
197  for (auto pair : llvm::zip(strides, op.getIndices())) {
198  int64_t stride = std::get<0>(pair);
199  Value idx = std::get<1>(pair);
200  std::optional<uint32_t> idxVal = getConstantUint32(idx);
201  if (!idxVal)
202  return false;
203  indexVal += stride * *idxVal;
204  }
205  result += indexVal;
206  if (result > std::numeric_limits<uint32_t>::max())
207  // Overflow means don't drop
208  return false;
209  return result >= bufferType.getNumElements();
210 }
211 
212 namespace {
213 template <typename OpType>
214 struct RemoveStaticallyOobBufferLoads final : public OpRewritePattern<OpType> {
216 
217  LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
218  if (!staticallyOutOfBounds(op))
219  return failure();
220  Type loadType = op.getResult().getType();
221  rw.replaceOpWithNewOp<arith::ConstantOp>(op, loadType,
222  rw.getZeroAttr(loadType));
223  return success();
224  }
225 };
226 
227 template <typename OpType>
228 struct RemoveStaticallyOobBufferWrites final : public OpRewritePattern<OpType> {
230 
231  LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
232  if (!staticallyOutOfBounds(op))
233  return failure();
234 
235  rw.eraseOp(op);
236  return success();
237  }
238 };
239 } // end namespace
240 
// Register the statically-out-of-bounds canonicalizations declared above:
// ops that produce a result (load, cmpswap) are folded to a zero constant,
// while pure writes and result-less atomics are erased.
void RawBufferLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferLoads<RawBufferLoadOp>>(context);
}

void RawBufferStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferStoreOp>>(context);
}

void RawBufferAtomicFaddOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFaddOp>>(context);
}

void RawBufferAtomicFmaxOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFmaxOp>>(context);
}

void RawBufferAtomicSmaxOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicSmaxOp>>(context);
}

void RawBufferAtomicUminOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicUminOp>>(context);
}

// Cmpswap returns the previous value, so it folds like a load.
void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferLoads<RawBufferAtomicCmpswapOp>>(
      context);
}
276 
277 //===----------------------------------------------------------------------===//
278 // WMMAOp
279 //===----------------------------------------------------------------------===//
280 LogicalResult WMMAOp::verify() {
281  Type sourceAType = getSourceA().getType();
282  Type sourceBType = getSourceB().getType();
283  Type destType = getDestC().getType();
284 
285  VectorType sourceVectorAType = dyn_cast<VectorType>(sourceAType);
286  VectorType sourceVectorBType = dyn_cast<VectorType>(sourceBType);
287  VectorType destVectorType = dyn_cast<VectorType>(destType);
288 
289  Type sourceAElemType = sourceVectorAType.getElementType();
290  Type sourceBElemType = sourceVectorBType.getElementType();
291  Type destElemType = destVectorType.getElementType();
292 
293  if (sourceVectorAType.getNumElements() !=
294  sourceVectorBType.getNumElements()) {
295  return emitOpError("source vectors have different lengths: ")
296  << sourceVectorAType << " vs. " << sourceVectorBType;
297  }
298 
299  bool isDestFloat = isa<Float32Type, Float16Type, BFloat16Type>(destElemType);
300  bool isSrcFloat =
301  isa<Float16Type, BFloat16Type, Float8E4M3FNType, Float8E5M2Type>(
302  sourceAElemType);
303 
304  if (isDestFloat && !isSrcFloat) {
305  return emitOpError("Expected float sources with float destination");
306  }
307 
308  if (!isDestFloat && isSrcFloat) {
309  return emitOpError("Expected int sources with int destination");
310  }
311 
312  if (sourceAElemType != sourceBElemType &&
313  !(isa<Float8E5M2Type, Float8E4M3FNType>(sourceAElemType) &&
314  isa<Float8E5M2Type, Float8E4M3FNType>(sourceBElemType))) {
315  return emitOpError(
316  "source element types much match (except for fp8) but have ")
317  << sourceAType << " and " << sourceBType;
318  }
319  return success();
320 }
321 
322 //===----------------------------------------------------------------------===//
323 // MFMAOp
324 //===----------------------------------------------------------------------===//
/// Verify an MFMA op: source/dest element counts must match what the m/n/k/
/// blocks attributes imply for a 64-lane wave, f8 sources must pair with f8
/// sources, and the f64-only and cbsz/abid restrictions must hold.
LogicalResult MFMAOp::verify() {
  // Element-count math below assumes the full 64-lane wavefront.
  constexpr uint32_t waveSize = 64;
  Builder b(getContext());

  Type sourceType = getSourceA().getType();
  Type destType = getDestC().getType();

  // Scalar operands are treated as length-1 vectors.
  Type sourceElem = sourceType, destElem = destType;
  uint32_t sourceLen = 1, destLen = 1;
  if (auto sourceVector = llvm::dyn_cast<VectorType>(sourceType)) {
    sourceLen = sourceVector.getNumElements();
    sourceElem = sourceVector.getElementType();
  }
  if (auto destVector = llvm::dyn_cast<VectorType>(destType)) {
    destLen = destVector.getNumElements();
    destElem = destVector.getElementType();
  }

  Type sourceBType = getSourceB().getType();
  if (sourceElem.isFloat(8)) {
    // For f8, A and B need not have the same element type (flavors may mix),
    // but both must be f8 and the same length.
    int64_t sourceBLen = 1;
    Type sourceBElem = sourceBType;
    if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
      sourceBLen = sourceBVector.getNumElements();
      sourceBElem = sourceBVector.getElementType();
    }
    if (!sourceBElem.isFloat(8))
      return emitOpError("expected both source operands to have f8 elements");
    if (sourceLen != sourceBLen)
      return emitOpError(
          "expected both f8 source vectors to have the same length");
  } else {
    if (sourceType != sourceBType)
      return emitOpError(
          "expected both non-f8 source operand types to match exactly");
  }
  // Normalize the wider integer types the compiler expects to i8
  // (an i32 packs 4 i8 values, an i64 packs 8).
  if (sourceElem.isInteger(32)) {
    sourceLen *= 4;
    sourceElem = b.getI8Type();
  }
  if (sourceElem.isInteger(64)) {
    sourceLen *= 8;
    sourceElem = b.getI8Type();
  }

  // Each lane of the wave holds (m * k * blocks) / 64 source values.
  int64_t numSourceElems = (getM() * getK() * getBlocks()) / waveSize;
  if (sourceLen != numSourceElems)
    return emitOpError("expected " + Twine(numSourceElems) +
                       " source values for this operation but got " +
                       Twine(sourceLen));

  // And (m * n * blocks) / 64 accumulator/result values.
  int64_t numDestElems = (getM() * getN() * getBlocks()) / waveSize;
  if (destLen != numDestElems)
    return emitOpError("expected " + Twine(numDestElems) +
                       " result values for this operation but got " +
                       Twine(destLen));

  // f64 variants do not support the lane-permutation controls.
  if (destElem.isF64() && getBlgp() != MFMAPermB::none)
    return emitOpError(
        "double-precision ops do not support permuting lanes of B");
  if (destElem.isF64() && getCbsz() != 0)
    return emitOpError(
        "double-precision ops do not support permuting lanes of A");
  // abid selects one of 2**cbsz broadcast blocks, so it must fit that range.
  if (getAbid() >= (1u << getCbsz()))
    return emitOpError(
        "block ID for permuting A (abid) must be below 2 ** cbsz");

  // Conversely, the operand-negation flags exist only on the f64 variants.
  if ((getNegateA() || getNegateB() || getNegateC()) && !destElem.isF64())
    return emitOpError(
        "negation flags only available for double-precision operations");

  return success();
}
399 
400 //===----------------------------------------------------------------------===//
401 // DPPOp
402 //===----------------------------------------------------------------------===//
/// Verify a DPP (data-parallel primitives) permutation op: the source must
/// fit in 64 bits and the permutation argument must match what the selected
/// permutation kind expects.
LogicalResult DPPOp::verify() {
  Type srcType = getSrc().getType();
  // DPP operates on values of at most 64 bits.
  if (srcType.getIntOrFloatBitWidth() > 64) {
    return emitOpError("integer and floating point types larger than 64 bits "
                       "are not supported");
  }

  DPPPerm kind = getKind();
  // Null Attribute when no permutation argument was given.
  Attribute permArgument = getPermArgument().value_or(Attribute{});

  switch (kind) {

  case DPPPerm::quad_perm: {
    // quad_perm takes an array of exactly 4 lane selectors, each in [0, 3].
    auto quadPermAttr = dyn_cast_or_null<ArrayAttr>(permArgument);
    if (!quadPermAttr || quadPermAttr.size() != 4) {
      return emitOpError("quad_perm attribute must have exactly 4 elements");
    }
    for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
      int32_t num = elem.getInt();
      if (num < 0 || num > 3) {
        return emitOpError(
            "Each element of quad_perm must be in the range [0, 3]");
      }
    }
  } break;

  case DPPPerm::row_shl:
  case DPPPerm::row_shr:
  case DPPPerm::row_ror: {
    // Row shifts/rotates require a single integer distance in [1, 15].
    if (!permArgument) {
      return emitOpError("Attribute '" + Twine(stringifyDPPPerm(kind)) +
                         "' value not specified");
    }
    if (auto intAttr = dyn_cast<IntegerAttr>(permArgument)) {
      // NOTE(review): a negative attribute value wraps to a large uint32_t
      // and is rejected by the > 15 check; non-integer attributes pass
      // through unchecked here — confirm that is intentional.
      uint32_t attrValue = intAttr.getInt();
      if (attrValue < 1 || attrValue > 15) {
        return emitOpError("Attribute value must be between 1 and 15");
      }
    }
  } break;

  case DPPPerm::wave_shl:
  case DPPPerm::wave_shr:
  case DPPPerm::wave_rol:
  case DPPPerm::wave_ror:
  case DPPPerm::row_mirror:
  case DPPPerm::row_half_mirror:
  case DPPPerm::row_bcast_15:
  case DPPPerm::row_bcast_31: {
    // These kinds take no argument (or at most a unit attribute).
    if (permArgument && !isa<UnitAttr>(permArgument)) {
      return emitOpError("Expected unit attribute for permArgument, but found "
                         "non-trivial argument");
    }
    break;
  }
  }
  return success();
}
461 
462 #include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
463 
464 #define GET_ATTRDEF_CLASSES
465 #include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
466 
467 #define GET_OP_CLASSES
468 #include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
static LogicalResult verifyRawBufferOp(T &op)
static FailureOr< MemRefType > getFatRawBufferTypeLike(MemRefType source, bool resetOffset)
Convert the type source to one with the same sizes and strides - and offset, unless stripOffset is tr...
static std::optional< uint32_t > getConstantUint32(Value v)
static bool staticallyOutOfBounds(OpType op)
static MLIRContext * getContext(OpFoldResult val)
union mlir::linalg::@1179::ArityGroupAndKind::Kind kind
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:320
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:166
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
Definition: BuiltinTypes.h:187
Builder & setMemorySpace(Attribute newMemorySpace)
Definition: BuiltinTypes.h:192
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:803
This class provides an abstraction over the different types of ranges over Regions.
Definition: Region.h:346
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:865
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:554
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isFloat() const
Return true if this is an float type (with the specified width).
Definition: Types.cpp:45
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:56
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
uint64_t getN(LevelType lt)
Definition: Enums.h:442
uint64_t getM(LevelType lt)
Definition: Enums.h:443
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:527
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:425
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358