#include "llvm/ADT/TypeSwitch.h"

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.cpp.inc"

void AMDGPUDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
      >();
}
LogicalResult PackedTrunc2xFp8Op::verify() {
  if (getExisting() && getExisting().getType() != getResult().getType())
    return emitOpError("existing values must have same type as result");
  return success();
}

LogicalResult PackedStochRoundFp8Op::verify() {
  if (getExisting() && getExisting().getType() != getResult().getType())
    return emitOpError("existing values must have same type as result");
  return success();
}
/// Convert the type `source` to one with the same sizes and strides (and
/// offset, unless `resetOffset` is true) in the fat raw buffer address space.
static FailureOr<MemRefType> getFatRawBufferTypeLike(MemRefType source,
                                                     bool resetOffset) {
  MLIRContext *ctx = source.getContext();
  MemRefType::Builder mb(source);
  mb.setMemorySpace(AddressSpaceAttr::get(ctx, AddressSpace::FatRawBuffer));
  MemRefLayoutAttrInterface layout = source.getLayout();
  if (resetOffset && !layout.isIdentity()) {
    auto stridedLayout = dyn_cast<StridedLayoutAttr>(layout);
    if (!stridedLayout)
      return failure();
    mb.setLayout(StridedLayoutAttr::get(ctx, /*offset=*/0,
                                        stridedLayout.getStrides()));
  }
  return (MemRefType)(mb);
}
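// Illustrative note (not from the original source): with resetOffset set, a
// source such as memref<8x8xf32, strided<[8, 1], offset: ?>> keeps its sizes
// and strides, but the resulting fat-raw-buffer type forces the layout offset
// to zero, i.e. strided<[8, 1]>.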
LogicalResult FatRawBufferCastOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    ValueRange operands, DictionaryAttr attributes,
    OpaqueProperties properties, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  Adaptor adaptor(operands, attributes, properties, regions);
  auto sourceType =
      dyn_cast_if_present<MemRefType>(adaptor.getSource().getType());
  if (!sourceType)
    return failure();
  FailureOr<MemRefType> resultType =
      getFatRawBufferTypeLike(sourceType, adaptor.getResetOffset());
  if (failed(resultType))
    return failure();
  inferredReturnTypes = {*resultType};
  return success();
}
LogicalResult FatRawBufferCastOp::verify() {
  FailureOr<MemRefType> expectedResultType =
      getFatRawBufferTypeLike(getSource().getType(), getResetOffset());
  if (failed(expectedResultType))
    return emitOpError("source type ")
           << getSource().getType() << " can't have its offset reset";
  if (getResult().getType() != *expectedResultType)
    return emitOpError("expected result type to be ")
           << *expectedResultType << " but got " << getResult().getType();
  return success();
}
static bool hasGlobalMemorySpace(Attribute memorySpace) {
  if (!memorySpace)
    return true;
  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
    return intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
  if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
  return false;
}
static bool hasWorkgroupMemorySpace(Attribute memorySpace) {
  if (!memorySpace)
    return false;
  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
    return intMemorySpace.getInt() == 3;
  if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuMemorySpace.getValue() == gpu::AddressSpace::Workgroup;
  return false;
}
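// Illustrative summary (not from the original source) of the accepted memory
// spaces: hasGlobalMemorySpace matches a missing memory space, integer address
// spaces 0 and 1, and #gpu.address_space<global>; hasWorkgroupMemorySpace
// matches integer address space 3 and #gpu.address_space<workgroup>.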
template <typename T>
static LogicalResult verifyRawBufferOp(T &op) {
  MemRefType bufferType = llvm::cast<MemRefType>(op.getMemref().getType());
  Attribute memorySpace = bufferType.getMemorySpace();
  if (!hasGlobalMemorySpace(memorySpace))
    return op.emitOpError(
        "Buffer ops must operate on a memref in global memory");
  if (!bufferType.hasRank())
    return op.emitOpError(
        "Cannot meaningfully buffer_store to an unranked memref");
  if (static_cast<int64_t>(op.getIndices().size()) != bufferType.getRank())
    return op.emitOpError("Expected " + Twine(bufferType.getRank()) +
                          " indices to memref");
  return success();
}
static std::optional<uint32_t> getConstantUint32(Value v) {
  APInt cst;
  if (!v.getType().isInteger(32))
    return std::nullopt;
  if (matchPattern(v, m_ConstantInt(&cst)))
    return cst.getZExtValue();
  return std::nullopt;
}
template <typename OpType>
static bool staticallyOutOfBounds(OpType op) {
  if (!op.getBoundsCheck())
    return false;
  MemRefType bufferType = op.getMemref().getType();
  if (!bufferType.hasStaticShape())
    return false;
  int64_t offset;
  SmallVector<int64_t> strides;
  if (failed(bufferType.getStridesAndOffset(strides, offset)))
    return false;
  int64_t result = offset + op.getIndexOffset().value_or(0);
  if (op.getSgprOffset()) {
    std::optional<uint32_t> sgprOffset = getConstantUint32(op.getSgprOffset());
    if (!sgprOffset)
      return false;
    result += *sgprOffset;
  }
  if (strides.size() != op.getIndices().size())
    return false;
  int64_t indexVal = 0;
  for (auto pair : llvm::zip(strides, op.getIndices())) {
    int64_t stride = std::get<0>(pair);
    Value idx = std::get<1>(pair);
    std::optional<uint32_t> idxVal = getConstantUint32(idx);
    if (!idxVal)
      return false;
    indexVal += stride * *idxVal;
  }
  result += indexVal;
  return result >= bufferType.getNumElements();
}
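// Worked example (illustrative, not from the original source): for a buffer of
// type memref<16x16xf32> the strides are [16, 1] and there are 256 elements;
// a bounds-checked access at constant indices (20, 0) with no extra offsets
// linearizes to 20 * 16 + 0 * 1 = 320 >= 256, so it is statically out of
// bounds and the patterns below may remove it.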
template <typename OpType>
struct RemoveStaticallyOobBufferLoads final : public OpRewritePattern<OpType> {
  using OpRewritePattern<OpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
    if (!staticallyOutOfBounds(op))
      return failure();
    Type loadType = op.getResult().getType();
    rw.replaceOpWithNewOp<arith::ConstantOp>(op, loadType,
                                             rw.getZeroAttr(loadType));
    return success();
  }
};
template <typename OpType>
struct RemoveStaticallyOobBufferWrites final : public OpRewritePattern<OpType> {
  using OpRewritePattern<OpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
    if (!staticallyOutOfBounds(op))
      return failure();
    rw.eraseOp(op);
    return success();
  }
};
void RawBufferLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferLoads<RawBufferLoadOp>>(context);
}

void RawBufferStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferStoreOp>>(context);
}

void RawBufferAtomicFaddOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFaddOp>>(context);
}

void RawBufferAtomicFmaxOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFmaxOp>>(context);
}

void RawBufferAtomicSmaxOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicSmaxOp>>(context);
}

void RawBufferAtomicUminOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicUminOp>>(context);
}

void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferLoads<RawBufferAtomicCmpswapOp>>(
      context);
}
LogicalResult WMMAOp::verify() {
  Type sourceAType = getSourceA().getType();
  Type sourceBType = getSourceB().getType();
  Type destType = getDestC().getType();

  VectorType sourceVectorAType = dyn_cast<VectorType>(sourceAType);
  VectorType sourceVectorBType = dyn_cast<VectorType>(sourceBType);
  VectorType destVectorType = dyn_cast<VectorType>(destType);

  Type sourceAElemType = sourceVectorAType.getElementType();
  Type sourceBElemType = sourceVectorBType.getElementType();
  Type destElemType = destVectorType.getElementType();

  if (sourceVectorAType.getNumElements() !=
      sourceVectorBType.getNumElements()) {
    return emitOpError("source vectors have different lengths: ")
           << sourceVectorAType << " vs. " << sourceVectorBType;
  }

  bool isDestFloat = isa<Float32Type, Float16Type, BFloat16Type>(destElemType);
  bool isSrcFloat =
      isa<Float16Type, BFloat16Type, Float8E4M3FNType, Float8E5M2Type>(
          sourceAElemType);

  if (isDestFloat && !isSrcFloat) {
    return emitOpError("Expected float sources with float destination");
  }
  if (!isDestFloat && isSrcFloat) {
    return emitOpError("Expected int sources with int destination");
  }

  if (sourceAElemType != sourceBElemType &&
      !(isa<Float8E5M2Type, Float8E4M3FNType>(sourceAElemType) &&
        isa<Float8E5M2Type, Float8E4M3FNType>(sourceBElemType))) {
    return emitOpError(
               "source element types must match (except for fp8) but have ")
           << sourceAType << " and " << sourceBType;
  }
  return success();
}
LogicalResult MFMAOp::verify() {
  constexpr uint32_t waveSize = 64;
  Builder b(getContext());

  Type sourceType = getSourceA().getType();
  Type destType = getDestC().getType();

  Type sourceElem = sourceType, destElem = destType;
  uint32_t sourceLen = 1, destLen = 1;
  if (auto sourceVector = llvm::dyn_cast<VectorType>(sourceType)) {
    sourceLen = sourceVector.getNumElements();
    sourceElem = sourceVector.getElementType();
  }
  if (auto destVector = llvm::dyn_cast<VectorType>(destType)) {
    destLen = destVector.getNumElements();
    destElem = destVector.getElementType();
  }

  // Small-float (8-bit and narrower) sources may mix element types between A
  // and B but must have matching lengths; all other sources must match
  // exactly.
  auto isSmallFloat = [](Type t) {
    return t.isFloat() && t.getIntOrFloatBitWidth() <= 8;
  };
  Type sourceBType = getSourceB().getType();
  if (isSmallFloat(sourceElem)) {
    int64_t sourceBLen = 1;
    Type sourceBElem = sourceBType;
    if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
      sourceBLen = sourceBVector.getNumElements();
      sourceBElem = sourceBVector.getElementType();
    }
    if (!isSmallFloat(sourceBElem))
      return emitOpError("expected both source operands to have small-float "
                         "elements if one does");
    if (sourceLen != sourceBLen)
      return emitOpError(
          "expected both small-float source vectors to have the same length");
  } else {
    if (sourceType != sourceBType)
      return emitOpError("expected both non-small-float source operand types "
                         "to match exactly");
  }

  // Normalize packed i32/i64 sources to their underlying i8 elements.
  if (sourceElem.isInteger(32)) {
    sourceLen *= 4;
    sourceElem = b.getI8Type();
  }
  if (sourceElem.isInteger(64)) {
    sourceLen *= 8;
    sourceElem = b.getI8Type();
  }

  int64_t numSourceElems = (getM() * getK() * getBlocks()) / waveSize;
  if (sourceLen != numSourceElems)
    return emitOpError("expected " + Twine(numSourceElems) +
                       " source values for this operation but got " +
                       Twine(sourceLen));

  int64_t numDestElems = (getM() * getN() * getBlocks()) / waveSize;
  if (destLen != numDestElems)
    return emitOpError("expected " + Twine(numDestElems) +
                       " result values for this operation but got " +
                       Twine(destLen));

  if (destElem.isF64() && getBlgp() != MFMAPermB::none)
    return emitOpError(
        "double-precision ops do not support permuting lanes of B");
  if (destElem.isF64() && getCbsz() != 0)
    return emitOpError(
        "double-precision ops do not support permuting lanes of A");
  if (getAbid() >= (1u << getCbsz()))
    return emitOpError(
        "block ID for permuting A (abid) must be below 2 ** cbsz");

  if ((getNegateA() || getNegateB() || getNegateC()) && !destElem.isF64())
    return emitOpError(
        "negation flags only available for double-precision operations");
  return success();
}
LogicalResult DPPOp::verify() {
  Type srcType = getSrc().getType();
  if (srcType.getIntOrFloatBitWidth() > 64) {
    return emitOpError("integer and floating point types larger than 64 bits "
                       "are not supported");
  }

  DPPPerm kind = getKind();
  Attribute permArgument = getPermArgument().value_or(Attribute{});

  switch (kind) {
  case DPPPerm::quad_perm: {
    auto quadPermAttr = dyn_cast_or_null<ArrayAttr>(permArgument);
    if (!quadPermAttr || quadPermAttr.size() != 4) {
      return emitOpError("quad_perm attribute must have exactly 4 elements");
    }
    for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
      int32_t num = elem.getInt();
      if (num < 0 || num > 3) {
        return emitOpError(
            "Each element of quad_perm must be in the range [0, 3]");
      }
    }
    break;
  }

  case DPPPerm::row_shl:
  case DPPPerm::row_shr:
  case DPPPerm::row_ror: {
    if (!permArgument) {
      return emitOpError("Attribute '" + Twine(stringifyDPPPerm(kind)) +
                         "' value not specified");
    }
    if (auto intAttr = dyn_cast<IntegerAttr>(permArgument)) {
      uint32_t attrValue = intAttr.getInt();
      if (attrValue < 1 || attrValue > 15) {
        return emitOpError("Attribute value must be between 1 and 15");
      }
    }
    break;
  }

  case DPPPerm::wave_shl:
  case DPPPerm::wave_shr:
  case DPPPerm::wave_rol:
  case DPPPerm::wave_ror:
  case DPPPerm::row_mirror:
  case DPPPerm::row_half_mirror:
  case DPPPerm::row_bcast_15:
  case DPPPerm::row_bcast_31: {
    if (permArgument && !isa<UnitAttr>(permArgument)) {
      return emitOpError("Expected unit attribute for permArgument, but found "
                         "non-trivial argument");
    }
    break;
  }
  }
  return success();
}
LogicalResult GatherToLDSOp::verify() {
  MemRefType srcType = cast<MemRefType>(getSrc().getType());
  MemRefType dstType = cast<MemRefType>(getDst().getType());

  if (!memref::isStaticShapeAndContiguousRowMajor(dstType))
    return emitOpError(
        "destination type must have a static shape and be contiguous");

  auto elemType = srcType.getElementType();
  // Check that the source and destination element types match.
  if (elemType != dstType.getElementType())
    return emitOpError("source and destination element types must match");

  auto transferType = getTransferType();
  int64_t transferSize;
  if (auto vectorTransfer = dyn_cast<VectorType>(transferType)) {
    transferSize = vectorTransfer.getNumElements() *
                   vectorTransfer.getElementTypeBitWidth();
  } else {
    transferSize = transferType.getIntOrFloatBitWidth();
  }
  if (transferSize != 8 && transferSize != 16 && transferSize != 32)
    return emitOpError("Transferring type size must be 8, 16, or 32 bits");

  if (!hasGlobalMemorySpace(srcType.getMemorySpace()))
    return emitOpError("source memory address space must be Global");

  if (!hasWorkgroupMemorySpace(dstType.getMemorySpace()))
    return emitOpError("destination memory address space must be Workgroup");

  return success();
}
#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"