27 #include "llvm/ADT/DenseMap.h"
28 #include "llvm/ADT/TypeSwitch.h"
36 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.cpp.inc"
38 void AMDGPUDialect::initialize() {
41 #include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
44 #define GET_ATTRDEF_LIST
45 #include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
53 if (getExisting() && getExisting().
getType() != getResult().
getType())
54 return emitOpError(
"existing values must have same type as result");
59 if (getExisting() && getExisting().
getType() != getResult().
getType())
60 return emitOpError(
"existing values must have same type as result");
68 if (getExisting() && getExisting().
getType() != getResult().
getType())
69 return emitOpError(
"existing values must have same type as result");
87 MemRefLayoutAttrInterface layout = source.getLayout();
88 if (resetOffset && !layout.isIdentity()) {
89 auto stridedLayout = dyn_cast<StridedLayoutAttr>(layout);
94 return (MemRefType)(mb);
97 LogicalResult FatRawBufferCastOp::inferReturnTypes(
101 Adaptor adaptor(operands, attributes, properties, regions);
103 dyn_cast_if_present<MemRefType>(adaptor.getSource().getType());
106 FailureOr<MemRefType> resultType =
108 if (failed(resultType))
115 FailureOr<MemRefType> expectedResultType =
117 if (failed(expectedResultType))
118 return emitOpError(
"source type ")
119 << getSource().getType() <<
" can't have its offset reset";
120 if (getResult().
getType() != *expectedResultType)
121 return emitOpError(
"expected result type to be ")
122 << *expectedResultType <<
" but got " << getResult().getType();
129 if (
auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
130 return intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
131 if (
auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
132 return gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
137 if (
auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
138 return intMemorySpace.getInt() == 3;
139 if (
auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
140 return gpuMemorySpace.getValue() == gpu::AddressSpace::Workgroup;
145 if (
auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
146 return intMemorySpace.getInt() == 7;
147 if (
auto gpuMemorySpace = dyn_cast<amdgpu::AddressSpaceAttr>(memorySpace))
148 return gpuMemorySpace.getValue() == amdgpu::AddressSpace::FatRawBuffer;
155 template <
typename T>
157 MemRefType bufferType = llvm::cast<MemRefType>(op.getMemref().getType());
161 return op.emitOpError(
162 "Buffer ops must operate on a memref in global memory");
163 if (!bufferType.hasRank())
164 return op.emitOpError(
165 "Cannot meaningfully buffer_store to an unranked memref");
166 if (
static_cast<int64_t
>(op.getIndices().size()) != bufferType.getRank())
167 return op.emitOpError(
"Expected " + Twine(bufferType.getRank()) +
168 " indices to memref");
201 return cst.getZExtValue();
205 template <
typename OpType>
207 if (!op.getBoundsCheck())
209 MemRefType bufferType = op.getMemref().getType();
210 if (!bufferType.hasStaticShape())
214 if (failed(bufferType.getStridesAndOffset(strides, offset)))
216 int64_t result = offset + op.getIndexOffset().value_or(0);
217 if (op.getSgprOffset()) {
221 result += *sgprOffset;
223 if (strides.size() != op.getIndices().size())
225 int64_t indexVal = 0;
226 for (
auto pair : llvm::zip(strides, op.getIndices())) {
227 int64_t stride = std::get<0>(pair);
228 Value idx = std::get<1>(pair);
232 indexVal += stride * *idxVal;
238 return result >= bufferType.getNumElements();
242 template <
typename OpType>
243 struct RemoveStaticallyOobBufferLoads final :
public OpRewritePattern<OpType> {
246 LogicalResult matchAndRewrite(OpType op,
PatternRewriter &rw)
const override {
249 Type loadType = op.getResult().getType();
256 template <
typename OpType>
257 struct RemoveStaticallyOobBufferWrites final :
public OpRewritePattern<OpType> {
260 LogicalResult matchAndRewrite(OpType op,
PatternRewriter &rw)
const override {
272 results.
add<RemoveStaticallyOobBufferLoads<RawBufferLoadOp>>(context);
277 results.
add<RemoveStaticallyOobBufferWrites<RawBufferStoreOp>>(context);
280 void RawBufferAtomicFaddOp::getCanonicalizationPatterns(
282 results.
add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFaddOp>>(context);
285 void RawBufferAtomicFmaxOp::getCanonicalizationPatterns(
287 results.
add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFmaxOp>>(context);
290 void RawBufferAtomicSmaxOp::getCanonicalizationPatterns(
292 results.
add<RemoveStaticallyOobBufferWrites<RawBufferAtomicSmaxOp>>(context);
295 void RawBufferAtomicUminOp::getCanonicalizationPatterns(
297 results.
add<RemoveStaticallyOobBufferWrites<RawBufferAtomicUminOp>>(context);
300 void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
302 results.
add<RemoveStaticallyOobBufferLoads<RawBufferAtomicCmpswapOp>>(
310 Type sourceAType = getSourceA().getType();
311 Type sourceBType = getSourceB().getType();
312 Type destType = getDestC().getType();
314 VectorType sourceVectorAType = dyn_cast<VectorType>(sourceAType);
315 VectorType sourceVectorBType = dyn_cast<VectorType>(sourceBType);
316 VectorType destVectorType = dyn_cast<VectorType>(destType);
318 Type sourceAElemType = sourceVectorAType.getElementType();
319 Type sourceBElemType = sourceVectorBType.getElementType();
320 Type destElemType = destVectorType.getElementType();
322 if (sourceVectorAType.getNumElements() !=
323 sourceVectorBType.getNumElements()) {
324 return emitOpError(
"source vectors have different lengths: ")
325 << sourceVectorAType <<
" vs. " << sourceVectorBType;
328 bool isDestFloat = isa<Float32Type, Float16Type, BFloat16Type>(destElemType);
330 isa<Float16Type, BFloat16Type, Float8E4M3FNType, Float8E5M2Type>(
333 if (isDestFloat && !isSrcFloat) {
334 return emitOpError(
"Expected float sources with float destination");
337 if (!isDestFloat && isSrcFloat) {
338 return emitOpError(
"Expected int sources with int destination");
341 if (sourceAElemType != sourceBElemType &&
342 !(isa<Float8E5M2Type, Float8E4M3FNType>(sourceAElemType) &&
343 isa<Float8E5M2Type, Float8E4M3FNType>(sourceBElemType))) {
345 "source element types much match (except for fp8) but have ")
346 << sourceAType <<
" and " << sourceBType;
355 constexpr uint32_t waveSize = 64;
358 Type sourceType = getSourceA().getType();
359 Type destType = getDestC().getType();
361 Type sourceElem = sourceType, destElem = destType;
362 uint32_t sourceLen = 1, destLen = 1;
363 if (
auto sourceVector = llvm::dyn_cast<VectorType>(sourceType)) {
364 sourceLen = sourceVector.getNumElements();
365 sourceElem = sourceVector.getElementType();
367 if (
auto destVector = llvm::dyn_cast<VectorType>(destType)) {
368 destLen = destVector.getNumElements();
369 destElem = destVector.getElementType();
372 Type sourceBType = getSourceB().getType();
374 int64_t sourceBLen = 1;
375 Type sourceBElem = sourceBType;
376 if (
auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
377 sourceBLen = sourceBVector.getNumElements();
378 sourceBElem = sourceBVector.getElementType();
382 return emitOpError(
"expected both source operands to have small-float "
383 "elements if one does");
384 if (sourceLen != sourceBLen)
386 "expected both small-float source vectors to have the same length");
388 if (sourceType != sourceBType)
389 return emitOpError(
"expected both non-small-float source operand types "
395 sourceElem = b.getI8Type();
399 sourceElem = b.getI8Type();
402 int64_t numSourceElems = (
getM() * getK() * getBlocks()) / waveSize;
403 if (sourceLen != numSourceElems)
404 return emitOpError(
"expected " + Twine(numSourceElems) +
405 " source values for this operation but got " +
408 int64_t numDestElems = (
getM() *
getN() * getBlocks()) / waveSize;
409 if (destLen != numDestElems)
410 return emitOpError(
"expected " + Twine(numDestElems) +
411 " result values for this operation but got " +
414 if (destElem.isF64() && getBlgp() != MFMAPermB::none)
416 "double-precision ops do not support permuting lanes of B");
417 if (destElem.isF64() && getCbsz() != 0)
419 "double-precision ops do not support permuting lanes of A");
420 if (getAbid() >= (1u << getCbsz()))
422 "block ID for permuting A (abid) must be below 2 ** cbsz");
424 if ((getNegateA() || getNegateB() || getNegateC()) && !destElem.isF64())
426 "negation flags only available for double-precision operations");
435 Type srcType = getSrc().getType();
437 return emitOpError(
"integer and floating point types larger than 64 bits "
438 "are not supported");
441 DPPPerm
kind = getKind();
446 case DPPPerm::quad_perm: {
447 auto quadPermAttr = dyn_cast_or_null<ArrayAttr>(permArgument);
448 if (!quadPermAttr || quadPermAttr.size() != 4) {
449 return emitOpError(
"quad_perm attribute must have exactly 4 elements");
451 for (
auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
452 int32_t num = elem.getInt();
453 if (num < 0 || num > 3) {
455 "Each element of quad_perm must be in the range [0, 3]");
460 case DPPPerm::row_shl:
461 case DPPPerm::row_shr:
462 case DPPPerm::row_ror: {
464 return emitOpError(
"Attribute '" + Twine(stringifyDPPPerm(
kind)) +
465 "' value not specified");
467 if (
auto intAttr = dyn_cast<IntegerAttr>(permArgument)) {
468 uint32_t attrValue = intAttr.getInt();
469 if (attrValue < 1 || attrValue > 15) {
470 return emitOpError(
"Attribute value must be between 1 and 15");
475 case DPPPerm::wave_shl:
476 case DPPPerm::wave_shr:
477 case DPPPerm::wave_rol:
478 case DPPPerm::wave_ror:
479 case DPPPerm::row_mirror:
480 case DPPPerm::row_half_mirror:
481 case DPPPerm::row_bcast_15:
482 case DPPPerm::row_bcast_31: {
483 if (permArgument && !isa<UnitAttr>(permArgument)) {
484 return emitOpError(
"Expected unit attribute for permArgument, but found "
485 "non-trivial argument");
494 MemRefType srcType = cast<MemRefType>(getSrc().
getType());
495 MemRefType dstType = cast<MemRefType>(getDst().
getType());
497 if (!dstType.areTrailingDimsContiguous(dstType.getRank()))
498 return emitOpError(
"destination types must be contiguous");
500 auto elemType = srcType.getElementType();
502 if (elemType != dstType.getElementType())
503 return emitOpError(
"source and destination element types must match");
506 auto transferType = getTransferType();
508 if (
auto vectorTransfer = dyn_cast<VectorType>(transferType)) {
509 transferSize = vectorTransfer.getNumElements() *
510 vectorTransfer.getElementTypeBitWidth();
512 transferSize = transferType.getIntOrFloatBitWidth();
514 if (!llvm::is_contained({8, 16, 32, 96, 128}, transferSize))
516 "Transfering type size must be 8, 16, 32, 96 or 128 bits");
521 "source memory address space must be global or fat raw buffer");
524 return emitOpError(
"destination memory address space must be Workgroup");
530 MemRefType srcType = cast<MemRefType>(getSrc().
getType());
533 return emitOpError(
"source memory address space must be Workgroup");
535 auto transferType = cast<VectorType>(
getType());
536 size_t numElements = transferType.getNumElements();
537 size_t elementTypeSize =
538 transferType.getElementType().getIntOrFloatBitWidth();
541 const llvm::SmallDenseMap<size_t, size_t> KValidLoadSizeMap = {
548 auto validNumElems = KValidLoadSizeMap.find(elementTypeSize);
549 if (validNumElems == KValidLoadSizeMap.end()) {
550 return emitOpError(
"Unsupported element type size for transpose load: ")
551 << elementTypeSize <<
" bits";
553 if (numElements != validNumElems->second) {
555 "Transferring type size mismatch: expected num of elements: ")
556 << validNumElems->second;
562 #include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
564 #define GET_ATTRDEF_CLASSES
565 #include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
567 #define GET_OP_CLASSES
568 #include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
static LogicalResult verifyRawBufferOp(T &op)
static FailureOr< MemRefType > getFatRawBufferTypeLike(MemRefType source, bool resetOffset)
Convert the type source to one with the same sizes and strides - and offset, unless stripOffset is tr...
static bool hasGlobalMemorySpace(Attribute memorySpace)
static bool hasWorkgroupMemorySpace(Attribute memorySpace)
static std::optional< uint32_t > getConstantUint32(Value v)
static bool hasFatRawBufferMemorySpace(Attribute memorySpace)
static bool staticallyOutOfBounds(OpType op)
static MLIRContext * getContext(OpFoldResult val)
union mlir::linalg::@1218::ArityGroupAndKind::Kind kind
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Attributes are known-constant values of operations.
This class is a general helper class for creating context-global objects like types,...
TypedAttr getZeroAttr(Type type)
MLIRContext is the top-level object for a collection of MLIR operations.
This is a builder type that keeps local references to arguments.
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
Builder & setMemorySpace(Attribute newMemorySpace)
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class provides an abstraction over the different types of ranges over Regions.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isFloat() const
Return true if this is an float type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
uint64_t getN(LevelType lt)
uint64_t getM(LevelType lt)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...