27 #include "llvm/ADT/TypeSwitch.h"
35 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.cpp.inc"
37 void AMDGPUDialect::initialize() {
40 #include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
43 #define GET_ATTRDEF_LIST
44 #include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
if (getExisting() && getExisting().getType() != getResult().getType())
  return emitOpError("existing values must have same type as result");

if (getExisting() && getExisting().getType() != getResult().getType())
  return emitOpError("existing values must have same type as result");

if (getExisting() && getExisting().getType() != getResult().getType())
  return emitOpError("existing values must have same type as result");
/// Convert the type `source` to one with the same sizes and strides - and
/// offset, unless `stripOffset` is true.
static FailureOr<MemRefType> getFatRawBufferTypeLike(MemRefType source,
                                                     bool resetOffset) {
  // ... (`mb` is a MemRefType::Builder seeded from `source`)
  MemRefLayoutAttrInterface layout = source.getLayout();
  if (resetOffset && !layout.isIdentity()) {
    auto stridedLayout = dyn_cast<StridedLayoutAttr>(layout);
    // ...
  }
  return (MemRefType)(mb);
}
LogicalResult FatRawBufferCastOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    ValueRange operands, DictionaryAttr attributes,
    OpaqueProperties properties, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  Adaptor adaptor(operands, attributes, properties, regions);
  auto sourceType =
      dyn_cast_if_present<MemRefType>(adaptor.getSource().getType());
  // ...
  FailureOr<MemRefType> resultType =
      getFatRawBufferTypeLike(sourceType, adaptor.getResetOffset());
  if (failed(resultType))
    return failure();
  // ...
}
LogicalResult FatRawBufferCastOp::verify() {
  FailureOr<MemRefType> expectedResultType =
      getFatRawBufferTypeLike(getSource().getType(), getResetOffset());
  if (failed(expectedResultType))
    return emitOpError("source type ")
           << getSource().getType() << " can't have its offset reset";
  if (getResult().getType() != *expectedResultType)
    return emitOpError("expected result type to be ")
           << *expectedResultType << " but got " << getResult().getType();
  return success();
}
static bool hasGlobalMemorySpace(Attribute memorySpace) {
  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
    return intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
  if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
  return false;
}
static bool hasWorkgroupMemorySpace(Attribute memorySpace) {
  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
    return intMemorySpace.getInt() == 3;
  if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuMemorySpace.getValue() == gpu::AddressSpace::Workgroup;
  return false;
}
static bool hasFatRawBufferMemorySpace(Attribute memorySpace) {
  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
    return intMemorySpace.getInt() == 7;
  if (auto gpuMemorySpace = dyn_cast<amdgpu::AddressSpaceAttr>(memorySpace))
    return gpuMemorySpace.getValue() == amdgpu::AddressSpace::FatRawBuffer;
  return false;
}
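
// Common verifier for the raw buffer ops: the memref being accessed must live
// in global memory, be ranked, and be addressed with one index per dimension.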
template <typename T>
static LogicalResult verifyRawBufferOp(T &op) {
  MemRefType bufferType = llvm::cast<MemRefType>(op.getMemref().getType());
  bool isGlobal = hasGlobalMemorySpace(bufferType.getMemorySpace());
  if (!isGlobal)
    return op.emitOpError(
        "Buffer ops must operate on a memref in global memory");
  if (!bufferType.hasRank())
    return op.emitOpError(
        "Cannot meaningfully buffer_store to an unranked memref");
  if (static_cast<int64_t>(op.getIndices().size()) != bufferType.getRank())
    return op.emitOpError("Expected " + Twine(bufferType.getRank()) +
                          " indices to memref");
  return success();
}
static std::optional<uint32_t> getConstantUint32(Value v) {
  APInt cst;
  // ...
  if (matchPattern(v, m_ConstantInt(&cst)))
    return cst.getZExtValue();
  return std::nullopt;
}
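
// Returns true when a bounds-checked raw buffer access can be shown, at
// compile time, to start at or past the end of the buffer. Such accesses are
// folded away by the patterns below.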
template <typename OpType>
static bool staticallyOutOfBounds(OpType op) {
  if (!op.getBoundsCheck())
    return false;
  MemRefType bufferType = op.getMemref().getType();
  if (!bufferType.hasStaticShape())
    return false;
  int64_t offset;
  SmallVector<int64_t> strides;
  if (failed(bufferType.getStridesAndOffset(strides, offset)))
    return false;
  int64_t result = offset + op.getIndexOffset().value_or(0);
  if (op.getSgprOffset()) {
    std::optional<uint32_t> sgprOffset = getConstantUint32(op.getSgprOffset());
    if (!sgprOffset)
      return false;
    result += *sgprOffset;
  }
  if (strides.size() != op.getIndices().size())
    return false;
  int64_t indexVal = 0;
  for (auto pair : llvm::zip(strides, op.getIndices())) {
    int64_t stride = std::get<0>(pair);
    Value idx = std::get<1>(pair);
    std::optional<uint32_t> idxVal = getConstantUint32(idx);
    if (!idxVal)
      return false;
    indexVal += stride * *idxVal;
  }
  result += indexVal;
  // ...
  return result >= bufferType.getNumElements();
}
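
// Canonicalization: buffer accesses that are statically out of bounds are
// dropped. Ops that produce a value are replaced by a zero constant of their
// result type; ops with no results are simply erased.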
template <typename OpType>
struct RemoveStaticallyOobBufferLoads final
    : public OpRewritePattern<OpType> {
  using OpRewritePattern<OpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpType op,
                                PatternRewriter &rw) const override {
    if (!staticallyOutOfBounds(op))
      return failure();
    Type loadType = op.getResult().getType();
    rw.replaceOpWithNewOp<arith::ConstantOp>(op, loadType,
                                             rw.getZeroAttr(loadType));
    return success();
  }
};
template <typename OpType>
struct RemoveStaticallyOobBufferWrites final
    : public OpRewritePattern<OpType> {
  using OpRewritePattern<OpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpType op,
                                PatternRewriter &rw) const override {
    if (!staticallyOutOfBounds(op))
      return failure();
    rw.eraseOp(op);
    return success();
  }
};
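
// Hook the patterns above into each raw buffer op's canonicalizer.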
void RawBufferLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferLoads<RawBufferLoadOp>>(context);
}

void RawBufferStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferStoreOp>>(context);
}

void RawBufferAtomicFaddOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFaddOp>>(context);
}

void RawBufferAtomicFmaxOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFmaxOp>>(context);
}

void RawBufferAtomicSmaxOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicSmaxOp>>(context);
}

void RawBufferAtomicUminOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicUminOp>>(context);
}

void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferLoads<RawBufferAtomicCmpswapOp>>(
      context);
}
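
// WMMAOp verifier: the A and B source vectors must have the same length, the
// source and destination element types must be consistently float or integer,
// and differing A/B element types are only allowed for the fp8 variants.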
LogicalResult WMMAOp::verify() {
  Type sourceAType = getSourceA().getType();
  Type sourceBType = getSourceB().getType();
  Type destType = getDestC().getType();

  VectorType sourceVectorAType = dyn_cast<VectorType>(sourceAType);
  VectorType sourceVectorBType = dyn_cast<VectorType>(sourceBType);
  VectorType destVectorType = dyn_cast<VectorType>(destType);

  Type sourceAElemType = sourceVectorAType.getElementType();
  Type sourceBElemType = sourceVectorBType.getElementType();
  Type destElemType = destVectorType.getElementType();

  if (sourceVectorAType.getNumElements() !=
      sourceVectorBType.getNumElements()) {
    return emitOpError("source vectors have different lengths: ")
           << sourceVectorAType << " vs. " << sourceVectorBType;
  }

  bool isDestFloat = isa<Float32Type, Float16Type, BFloat16Type>(destElemType);
  bool isSrcFloat =
      isa<Float16Type, BFloat16Type, Float8E4M3FNType, Float8E5M2Type>(
          sourceAElemType);

  if (isDestFloat && !isSrcFloat) {
    return emitOpError("Expected float sources with float destination");
  }

  if (!isDestFloat && isSrcFloat) {
    return emitOpError("Expected int sources with int destination");
  }

  if (sourceAElemType != sourceBElemType &&
      !(isa<Float8E5M2Type, Float8E4M3FNType>(sourceAElemType) &&
        isa<Float8E5M2Type, Float8E4M3FNType>(sourceBElemType))) {
    return emitOpError(
               "source element types must match (except for fp8) but have ")
           << sourceAType << " and " << sourceBType;
  }
  return success();
}
LogicalResult MFMAOp::verify() {
  constexpr uint32_t waveSize = 64;
  Builder b(getContext());

  Type sourceType = getSourceA().getType();
  Type destType = getDestC().getType();

  Type sourceElem = sourceType, destElem = destType;
  uint32_t sourceLen = 1, destLen = 1;
  if (auto sourceVector = llvm::dyn_cast<VectorType>(sourceType)) {
    sourceLen = sourceVector.getNumElements();
    sourceElem = sourceVector.getElementType();
  }
  if (auto destVector = llvm::dyn_cast<VectorType>(destType)) {
    destLen = destVector.getNumElements();
    destElem = destVector.getElementType();
  }

  Type sourceBType = getSourceB().getType();
  // ... (when the A elements are a small-float type, B is inspected
  // separately, since its element type may legitimately differ)
  int64_t sourceBLen = 1;
  Type sourceBElem = sourceBType;
  if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
    sourceBLen = sourceBVector.getNumElements();
    sourceBElem = sourceBVector.getElementType();
  }
  // ...
    return emitOpError("expected both source operands to have small-float "
                       "elements if one does");
  if (sourceLen != sourceBLen)
    return emitOpError(
        "expected both small-float source vectors to have the same length");
  // ...
  if (sourceType != sourceBType)
    return emitOpError("expected both non-small-float source operand types "
                       "to match");
  // ... (packed integer source element types are normalized to i8 for the
  // element counts below)
  sourceElem = b.getI8Type();
  // ...
  sourceElem = b.getI8Type();

  int64_t numSourceElems = (getM() * getK() * getBlocks()) / waveSize;
  if (sourceLen != numSourceElems)
    return emitOpError("expected " + Twine(numSourceElems) +
                       " source values for this operation but got " +
                       Twine(sourceLen));

  int64_t numDestElems = (getM() * getN() * getBlocks()) / waveSize;
  if (destLen != numDestElems)
    return emitOpError("expected " + Twine(numDestElems) +
                       " result values for this operation but got " +
                       Twine(destLen));

  if (destElem.isF64() && getBlgp() != MFMAPermB::none)
    return emitOpError(
        "double-precision ops do not support permuting lanes of B");
  if (destElem.isF64() && getCbsz() != 0)
    return emitOpError(
        "double-precision ops do not support permuting lanes of A");
  if (getAbid() >= (1u << getCbsz()))
    return emitOpError(
        "block ID for permuting A (abid) must be below 2 ** cbsz");

  if ((getNegateA() || getNegateB() || getNegateC()) && !destElem.isF64())
    return emitOpError(
        "negation flags only available for double-precision operations");

  return success();
}
LogicalResult DPPOp::verify() {
  Type srcType = getSrc().getType();
  if (srcType.getIntOrFloatBitWidth() > 64) {
    return emitOpError("integer and floating point types larger than 64 bits "
                       "are not supported");
  }

  DPPPerm kind = getKind();
  Attribute permArgument = getPermArgument().value_or(Attribute{});

  switch (kind) {
  case DPPPerm::quad_perm: {
    auto quadPermAttr = dyn_cast_or_null<ArrayAttr>(permArgument);
    if (!quadPermAttr || quadPermAttr.size() != 4) {
      return emitOpError("quad_perm attribute must have exactly 4 elements");
    }
    for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
      int32_t num = elem.getInt();
      if (num < 0 || num > 3) {
        return emitOpError(
            "Each element of quad_perm must be in the range [0, 3]");
      }
    }
    break;
  }

  case DPPPerm::row_shl:
  case DPPPerm::row_shr:
  case DPPPerm::row_ror: {
    if (!permArgument) {
      return emitOpError("Attribute '" + Twine(stringifyDPPPerm(kind)) +
                         "' value not specified");
    }
    if (auto intAttr = dyn_cast<IntegerAttr>(permArgument)) {
      uint32_t attrValue = intAttr.getInt();
      if (attrValue < 1 || attrValue > 15) {
        return emitOpError("Attribute value must be between 1 and 15");
      }
    }
    break;
  }

  case DPPPerm::wave_shl:
  case DPPPerm::wave_shr:
  case DPPPerm::wave_rol:
  case DPPPerm::wave_ror:
  case DPPPerm::row_mirror:
  case DPPPerm::row_half_mirror:
  case DPPPerm::row_bcast_15:
  case DPPPerm::row_bcast_31: {
    if (permArgument && !isa<UnitAttr>(permArgument)) {
      return emitOpError("Expected unit attribute for permArgument, but found "
                         "non-trivial argument");
    }
    break;
  }
  }
  return success();
}
LogicalResult GatherToLDSOp::verify() {
  MemRefType srcType = cast<MemRefType>(getSrc().getType());
  MemRefType dstType = cast<MemRefType>(getDst().getType());

  if (!dstType.areTrailingDimsContiguous(dstType.getRank()))
    return emitOpError("destination types must be contiguous");

  auto elemType = srcType.getElementType();
  if (elemType != dstType.getElementType())
    return emitOpError("source and destination element types must match");

  auto transferType = getTransferType();
  size_t transferSize;
  if (auto vectorTransfer = dyn_cast<VectorType>(transferType)) {
    transferSize = vectorTransfer.getNumElements() *
                   vectorTransfer.getElementTypeBitWidth();
  } else {
    transferSize = transferType.getIntOrFloatBitWidth();
  }
  if (transferSize != 8 && transferSize != 16 && transferSize != 32)
    return emitOpError("Transferring type size must be 8, 16, or 32 bits");

  if (!hasGlobalMemorySpace(srcType.getMemorySpace()) &&
      !hasFatRawBufferMemorySpace(srcType.getMemorySpace()))
    return emitOpError(
        "source memory address space must be global or fat raw buffer");

  if (!hasWorkgroupMemorySpace(dstType.getMemorySpace()))
    return emitOpError("destination memory address space must be Workgroup");

  return success();
}
#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"