25 #include "llvm/ADT/TypeSwitch.h"
33 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.cpp.inc"
35 void AMDGPUDialect::initialize() {
38 #include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
41 #define GET_ATTRDEF_LIST
42 #include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
50 if (getExisting() && getExisting().getType() != getResult().getType())
51 return emitOpError(
"existing values must have same type as result");
56 if (getExisting() && getExisting().getType() != getResult().getType())
57 return emitOpError(
"existing values must have same type as result");
66 MemRefType bufferType = llvm::cast<MemRefType>(op.getMemref().getType());
67 Attribute memorySpace = bufferType.getMemorySpace();
68 bool isGlobal =
false;
71 else if (
auto intMemorySpace = llvm::dyn_cast<IntegerAttr>(memorySpace))
72 isGlobal = intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
73 else if (
auto gpuMemorySpace =
74 llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
75 isGlobal = gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
79 "Buffer ops must operate on a memref in global memory");
80 if (!bufferType.hasRank())
82 "Cannot meaningfully buffer_store to an unranked memref");
83 if (
static_cast<int64_t
>(op.getIndices().size()) != bufferType.getRank())
84 return op.
emitOpError(
"Expected " + Twine(bufferType.getRank()) +
85 " indices to memref");
118 return cst.getZExtValue();
122 template <
typename OpType>
124 if (!op.getBoundsCheck())
126 MemRefType bufferType = op.getMemref().getType();
127 if (!bufferType.hasStaticShape())
133 int64_t result = offset + op.getIndexOffset().value_or(0);
134 if (op.getSgprOffset()) {
138 result += *sgprOffset;
140 if (strides.size() != op.getIndices().size())
142 int64_t indexVal = 0;
143 for (
auto pair : llvm::zip(strides, op.getIndices())) {
144 int64_t stride = std::get<0>(pair);
145 Value idx = std::get<1>(pair);
149 indexVal += stride * *idxVal;
155 return result >= bufferType.getNumElements();
159 template <
typename OpType>
160 struct RemoveStaticallyOobBufferLoads final :
public OpRewritePattern<OpType> {
173 template <
typename OpType>
174 struct RemoveStaticallyOobBufferWrites final :
public OpRewritePattern<OpType> {
189 results.
add<RemoveStaticallyOobBufferLoads<RawBufferLoadOp>>(context);
194 results.
add<RemoveStaticallyOobBufferWrites<RawBufferStoreOp>>(context);
197 void RawBufferAtomicFaddOp::getCanonicalizationPatterns(
199 results.
add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFaddOp>>(context);
202 void RawBufferAtomicFmaxOp::getCanonicalizationPatterns(
204 results.
add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFmaxOp>>(context);
207 void RawBufferAtomicSmaxOp::getCanonicalizationPatterns(
209 results.
add<RemoveStaticallyOobBufferWrites<RawBufferAtomicSmaxOp>>(context);
212 void RawBufferAtomicUminOp::getCanonicalizationPatterns(
214 results.
add<RemoveStaticallyOobBufferWrites<RawBufferAtomicUminOp>>(context);
217 void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
219 results.
add<RemoveStaticallyOobBufferLoads<RawBufferAtomicCmpswapOp>>(
227 Type sourceAType = getSourceA().getType();
228 Type destType = getDestC().getType();
230 VectorType sourceVectorAType = sourceAType.
dyn_cast<VectorType>();
231 VectorType destVectorType = destType.
dyn_cast<VectorType>();
233 Type sourceAElemType = sourceVectorAType.getElementType();
234 Type destElemType = destVectorType.getElementType();
238 bool isSrcFloat = (sourceAElemType.isF16() || sourceAElemType.isBF16());
240 if (isDestFloat && !isSrcFloat) {
241 return emitOpError(
"Expected float sources with float destination");
244 if (!isDestFloat && isSrcFloat) {
245 return emitOpError(
"Expected int sources with int destination");
255 constexpr uint32_t waveSize = 64;
258 Type sourceType = getSourceA().getType();
259 Type destType = getDestC().getType();
261 Type sourceElem = sourceType, destElem = destType;
262 uint32_t sourceLen = 1, destLen = 1;
263 if (
auto sourceVector = llvm::dyn_cast<VectorType>(sourceType)) {
264 sourceLen = sourceVector.getNumElements();
265 sourceElem = sourceVector.getElementType();
267 if (
auto destVector = llvm::dyn_cast<VectorType>(destType)) {
268 destLen = destVector.getNumElements();
269 destElem = destVector.getElementType();
272 Type sourceBType = getSourceB().getType();
274 int64_t sourceBLen = 1;
275 Type sourceBElem = sourceBType;
276 if (
auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
277 sourceBLen = sourceBVector.getNumElements();
278 sourceBElem = sourceBVector.getElementType();
281 return emitOpError(
"expected both source operands to have f8 elements");
282 if (sourceLen != sourceBLen)
284 "expected both f8 source vectors to have the same length");
286 if (sourceType != sourceBType)
288 "expected both non-f8 source operand types to match exactly");
293 sourceElem = b.getI8Type();
297 sourceElem = b.getI8Type();
300 int64_t numSourceElems = (getM() * getK() * getBlocks()) / waveSize;
301 if (sourceLen != numSourceElems)
302 return emitOpError(
"expected " + Twine(numSourceElems) +
303 " source values for this operation but got " +
306 int64_t numDestElems = (getM() * getN() * getBlocks()) / waveSize;
307 if (destLen != numDestElems)
308 return emitOpError(
"expected " + Twine(numDestElems) +
309 " result values for this operation but got " +
312 if (destElem.isF64() && getBlgp() != MFMAPermB::none)
314 "double-precision ops do not support permuting lanes of B");
315 if (destElem.isF64() && getCbsz() != 0)
317 "double-precision ops do not support permuting lanes of A");
318 if (getAbid() >= (1u << getCbsz()))
320 "block ID for permuting A (abid) must be below 2 ** cbsz");
322 if ((getNegateA() || getNegateB() || getNegateC()) && !destElem.isF64())
324 "negation flags only available for double-precision operations");
329 #include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
331 #define GET_ATTRDEF_CLASSES
332 #include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
334 #define GET_OP_CLASSES
335 #include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
static LogicalResult verifyRawBufferOp(T &op)
static std::optional< uint32_t > getConstantUint32(Value v)
static bool staticallyOutOfBounds(OpType op)
static MLIRContext * getContext(OpFoldResult val)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Attributes are known-constant values of operations.
This class is a general helper class for creating context-global objects like types, attributes, and affine expressions.
TypedAttr getZeroAttr(Type type)
MLIRContext is the top-level object for a collection of MLIR operations.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
bool isInteger(unsigned width) const
Return true if this is an integer type with the specified width.
bool isFloat8E4M3FNUZ() const
bool isFloat8E5M2FNUZ() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bind_value.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.