26 #include "llvm/ADT/TypeSwitch.h"
34 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.cpp.inc"
36 void AMDGPUDialect::initialize() {
39 #include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
42 #define GET_ATTRDEF_LIST
43 #include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
51 if (getExisting() && getExisting().
getType() != getResult().
getType())
52 return emitOpError(
"existing values must have same type as result");
57 if (getExisting() && getExisting().
getType() != getResult().
getType())
58 return emitOpError(
"existing values must have same type as result");
67 MemRefType bufferType = llvm::cast<MemRefType>(op.getMemref().getType());
68 Attribute memorySpace = bufferType.getMemorySpace();
69 bool isGlobal =
false;
72 else if (
auto intMemorySpace = llvm::dyn_cast<IntegerAttr>(memorySpace))
73 isGlobal = intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
74 else if (
auto gpuMemorySpace =
75 llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
76 isGlobal = gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
80 "Buffer ops must operate on a memref in global memory");
81 if (!bufferType.hasRank())
83 "Cannot meaningfully buffer_store to an unranked memref");
84 if (
static_cast<int64_t
>(op.getIndices().size()) != bufferType.getRank())
85 return op.
emitOpError(
"Expected " + Twine(bufferType.getRank()) +
86 " indices to memref");
119 return cst.getZExtValue();
123 template <
typename OpType>
125 if (!op.getBoundsCheck())
127 MemRefType bufferType = op.getMemref().getType();
128 if (!bufferType.hasStaticShape())
134 int64_t result = offset + op.getIndexOffset().value_or(0);
135 if (op.getSgprOffset()) {
139 result += *sgprOffset;
141 if (strides.size() != op.getIndices().size())
143 int64_t indexVal = 0;
144 for (
auto pair : llvm::zip(strides, op.getIndices())) {
145 int64_t stride = std::get<0>(pair);
146 Value idx = std::get<1>(pair);
150 indexVal += stride * *idxVal;
156 return result >= bufferType.getNumElements();
160 template <
typename OpType>
161 struct RemoveStaticallyOobBufferLoads final :
public OpRewritePattern<OpType> {
164 LogicalResult matchAndRewrite(OpType op,
PatternRewriter &rw)
const override {
174 template <
typename OpType>
175 struct RemoveStaticallyOobBufferWrites final :
public OpRewritePattern<OpType> {
178 LogicalResult matchAndRewrite(OpType op,
PatternRewriter &rw)
const override {
190 results.
add<RemoveStaticallyOobBufferLoads<RawBufferLoadOp>>(context);
195 results.
add<RemoveStaticallyOobBufferWrites<RawBufferStoreOp>>(context);
198 void RawBufferAtomicFaddOp::getCanonicalizationPatterns(
200 results.
add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFaddOp>>(context);
203 void RawBufferAtomicFmaxOp::getCanonicalizationPatterns(
205 results.
add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFmaxOp>>(context);
208 void RawBufferAtomicSmaxOp::getCanonicalizationPatterns(
210 results.
add<RemoveStaticallyOobBufferWrites<RawBufferAtomicSmaxOp>>(context);
213 void RawBufferAtomicUminOp::getCanonicalizationPatterns(
215 results.
add<RemoveStaticallyOobBufferWrites<RawBufferAtomicUminOp>>(context);
218 void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
220 results.
add<RemoveStaticallyOobBufferLoads<RawBufferAtomicCmpswapOp>>(
228 Type sourceAType = getSourceA().getType();
229 Type destType = getDestC().getType();
231 VectorType sourceVectorAType = dyn_cast<VectorType>(sourceAType);
232 VectorType destVectorType = dyn_cast<VectorType>(destType);
234 Type sourceAElemType = sourceVectorAType.getElementType();
235 Type destElemType = destVectorType.getElementType();
237 bool isDestFloat = isa<Float32Type, Float16Type, BFloat16Type>(destElemType);
239 isa<Float16Type, BFloat16Type, Float8E4M3FNType, Float8E5M2Type>(
242 if (isDestFloat && !isSrcFloat) {
243 return emitOpError(
"Expected float sources with float destination");
246 if (!isDestFloat && isSrcFloat) {
247 return emitOpError(
"Expected int sources with int destination");
257 constexpr uint32_t waveSize = 64;
260 Type sourceType = getSourceA().getType();
261 Type destType = getDestC().getType();
263 Type sourceElem = sourceType, destElem = destType;
264 uint32_t sourceLen = 1, destLen = 1;
265 if (
auto sourceVector = llvm::dyn_cast<VectorType>(sourceType)) {
266 sourceLen = sourceVector.getNumElements();
267 sourceElem = sourceVector.getElementType();
269 if (
auto destVector = llvm::dyn_cast<VectorType>(destType)) {
270 destLen = destVector.getNumElements();
271 destElem = destVector.getElementType();
274 Type sourceBType = getSourceB().getType();
276 int64_t sourceBLen = 1;
277 Type sourceBElem = sourceBType;
278 if (
auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
279 sourceBLen = sourceBVector.getNumElements();
280 sourceBElem = sourceBVector.getElementType();
283 return emitOpError(
"expected both source operands to have f8 elements");
284 if (sourceLen != sourceBLen)
286 "expected both f8 source vectors to have the same length");
288 if (sourceType != sourceBType)
290 "expected both non-f8 source operand types to match exactly");
295 sourceElem = b.getI8Type();
299 sourceElem = b.getI8Type();
302 int64_t numSourceElems = (
getM() * getK() * getBlocks()) / waveSize;
303 if (sourceLen != numSourceElems)
304 return emitOpError(
"expected " + Twine(numSourceElems) +
305 " source values for this operation but got " +
308 int64_t numDestElems = (
getM() *
getN() * getBlocks()) / waveSize;
309 if (destLen != numDestElems)
310 return emitOpError(
"expected " + Twine(numDestElems) +
311 " result values for this operation but got " +
314 if (destElem.isF64() && getBlgp() != MFMAPermB::none)
316 "double-precision ops do not support permuting lanes of B");
317 if (destElem.isF64() && getCbsz() != 0)
319 "double-precision ops do not support permuting lanes of A");
320 if (getAbid() >= (1u << getCbsz()))
322 "block ID for permuting A (abid) must be below 2 ** cbsz");
324 if ((getNegateA() || getNegateB() || getNegateC()) && !destElem.isF64())
326 "negation flags only available for double-precision operations");
335 Type srcType = getSrc().getType();
337 return emitOpError(
"integer and floating point types larger than 64 bits "
338 "are not supported");
341 DPPPerm kind = getKind();
346 case DPPPerm::quad_perm: {
347 auto quadPermAttr = dyn_cast_or_null<ArrayAttr>(permArgument);
348 if (!quadPermAttr || quadPermAttr.size() != 4) {
349 return emitOpError(
"quad_perm attribute must have exactly 4 elements");
351 for (
auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
352 uint32_t num = elem.getInt();
353 if (num < 0 || num > 3) {
355 "Each element of quad_perm must be in the range [0, 3]");
360 case DPPPerm::row_shl:
361 case DPPPerm::row_shr:
362 case DPPPerm::row_ror: {
364 return emitOpError(
"Attribute '" + Twine(stringifyDPPPerm(kind)) +
365 "' value not specified");
367 if (
auto intAttr = dyn_cast<IntegerAttr>(permArgument)) {
368 uint32_t attrValue = intAttr.getInt();
369 if (attrValue < 1 || attrValue > 15) {
370 return emitOpError(
"Attribute value must be between 1 and 15");
375 case DPPPerm::wave_shl:
376 case DPPPerm::wave_shr:
377 case DPPPerm::wave_rol:
378 case DPPPerm::wave_ror:
379 case DPPPerm::row_mirror:
380 case DPPPerm::row_half_mirror:
381 case DPPPerm::row_bcast_15:
382 case DPPPerm::row_bcast_31: {
383 if (permArgument && !isa<UnitAttr>(permArgument)) {
384 return emitOpError(
"Expected unit attribute for permArgument, but found "
385 "non-trivial argument");
393 #include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
395 #define GET_ATTRDEF_CLASSES
396 #include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
398 #define GET_OP_CLASSES
399 #include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
static LogicalResult verifyRawBufferOp(T &op)
static std::optional< uint32_t > getConstantUint32(Value v)
static bool staticallyOutOfBounds(OpType op)
static MLIRContext * getContext(OpFoldResult val)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Attributes are known-constant values of operations.
This class is a general helper class for creating context-global objects like types,...
TypedAttr getZeroAttr(Type type)
MLIRContext is the top-level object for a collection of MLIR operations.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isFloat8E4M3FNUZ() const
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
bool isFloat8E5M2FNUZ() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
uint64_t getN(LevelType lt)
uint64_t getM(LevelType lt)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...