28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/ADT/StringSwitch.h"
44 gpu::SubgroupMmaElementwiseOp op,
Type coopType,
46 assert((isa<spirv::CooperativeMatrixType>(coopType)));
48 switch (op.getOpType()) {
49 case gpu::MMAElementwiseOp::ADDF:
55 case gpu::MMAElementwiseOp::SUBF:
61 case gpu::MMAElementwiseOp::DIVF:
64 case gpu::MMAElementwiseOp::DIVS:
67 case gpu::MMAElementwiseOp::DIVU:
70 case gpu::MMAElementwiseOp::NEGATEF:
73 case gpu::MMAElementwiseOp::NEGATES:
76 case gpu::MMAElementwiseOp::EXTF:
86 assert(!operands.empty());
88 llvm::map_range(operands, [](
Value v) {
return v.
getType(); })))
91 return isa<spirv::CooperativeMatrixType>(operands.front().
getType());
// Conversion pattern for gpu::SubgroupMmaConstantMatrixOp.
//
// From the lines visible here, matchAndRewrite:
//   1. takes the single adapted operand of the op,
//   2. converts the op's result type through the pattern's type converter,
//   3. has a "type conversion failed" match-failure path, and
//   4. replaces the op with a spirv::CompositeConstructOp of the converted
//      type built from that one operand.
//
// NOTE(review): this listing skips source lines (see the embedded line
// numbers), so the condition guarding the failure return, the final return
// and the closing braces are not shown.
97 struct WmmaConstantOpToSPIRVLowering final
98 : OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
102 matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp op, OpAdaptor adaptor,
103 ConversionPatternRewriter &rewriter)
const override {
// The adapted operand list is reduced to its one element.
104 Value cst = llvm::getSingleElement(adaptor.getOperands());
// Target type for the replacement, obtained from the type converter.
105 auto coopType = getTypeConverter()->convertType(op.getType());
107 return rewriter.notifyMatchFailure(op,
"type conversion failed");
// The constant-matrix op becomes a composite construct of `cst`.
109 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, coopType, cst);
// Conversion pattern for gpu::SubgroupMmaElementwiseOp (the "default"
// lowering, going by the struct name).
//
// The visible lines show two match-failure paths, one reporting
// "not all operands are coop matrices" and one reporting
// "type conversion failed", and the conversion of the op's result type
// through the pattern's type converter.
//
// NOTE(review): this listing skips source lines, so the conditions guarding
// both failure returns and the code that builds the replacement op are not
// shown.
116 struct WmmaElementwiseOpToSPIRVDefaultLowering final
117 : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
121 matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
122 ConversionPatternRewriter &rewriter)
const override {
125 return rewriter.notifyMatchFailure(op,
126 "not all operands are coop matrices");
// Target type for the replacement, obtained from the type converter.
129 auto coopType = getTypeConverter()->convertType(op.getType());
131 return rewriter.notifyMatchFailure(op,
"type conversion failed");
// Conversion pattern for gpu::SubgroupMmaElementwiseOp that rewrites a MULF
// whose one operand comes from a gpu::SubgroupMmaConstantMatrixOp into a
// spirv::MatrixTimesScalarOp.
//
// From the lines visible here, matchAndRewrite:
//   1. checks that the adaptor carries exactly two operands,
//   2. checks that the op type is MULF,
//   3. inspects the ORIGINAL operands to find which side is defined by a
//      gpu::SubgroupMmaConstantMatrixOp, and picks the corresponding ADAPTED
//      values as `splat` and `matrix`,
//   4. fails with "no splat operand" if neither side qualified,
//   5. looks up the spirv::CompositeConstructOp defining the adapted splat
//      and takes its single constituent as the scalar,
//   6. converts the result type and replaces the op with
//      spirv::MatrixTimesScalarOp(matrix, scalar).
//
// NOTE(review): this listing skips source lines, so several guard conditions,
// the bodies of the checks in steps 1 and 2, the declaration of `scalar`
// and the closing braces are not shown.
140 struct WmmaElementwiseOpToSPIRVScalarMulLowering final
141 : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
145 matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
146 ConversionPatternRewriter &rewriter)
const override {
// This pattern only considers binary elementwise ops.
147 if (adaptor.getOperands().size() != 2)
152 return rewriter.notifyMatchFailure(op,
153 "not all operands are coop matrices");
// Only floating-point multiplication (MULF) is of interest here.
156 if (op.getOpType() != gpu::MMAElementwiseOp::MULF)
161 Value lhs = op.getOperands().front();
162 Value rhs = op.getOperands().back();
163 Value splat =
nullptr;
164 Value matrix =
nullptr;
// The defining-op test uses the original (pre-conversion) operands, while
// `splat` and `matrix` are taken from the adapted operands. If both
// sides are constant matrices, the left-hand side is chosen as the splat.
165 if (lhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
166 splat = adaptor.getOperands().front();
167 matrix = adaptor.getOperands().back();
168 }
else if (rhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
169 matrix = adaptor.getOperands().front();
170 splat = adaptor.getOperands().back();
// Both stay null when neither operand is defined by a constant-matrix op.
172 if (!splat || !matrix)
173 return rewriter.notifyMatchFailure(op,
"no splat operand");
// The adapted splat must itself be produced by a composite construct.
177 auto cc = splat.getDefiningOp<spirv::CompositeConstructOp>();
179 return rewriter.notifyMatchFailure(op,
180 "splat is not a composite construct");
// The single constituent of the composite construct is the scalar factor.
183 scalar = llvm::getSingleElement(cc.getConstituents());
// Target type for the replacement, obtained from the type converter.
185 auto coopType = getTypeConverter()->convertType(op.getType());
187 return rewriter.notifyMatchFailure(op,
"type conversion failed");
188 rewriter.replaceOpWithNewOp<spirv::MatrixTimesScalarOp>(
189 op, coopType, ValueRange{matrix, scalar});
// Conversion pattern whose matchAndRewrite takes a
// gpu::SubgroupMmaLoadMatrixOp.
//
// From the lines visible here, matchAndRewrite:
//   - fetches the SPIRVTypeConverter, the op's result type as a
//     gpu::MMAMatrixType, and the source memref type,
//   - passes the adapted indices, `loc` and the rewriter to a call whose
//     callee is not shown in this listing,
//   - reads the op's leading dimension as a signed 64-bit stride and starts
//     creating a spirv::ConstantOp for it,
//   - selects a column-major layout when the op's transpose attribute is
//     present and true, row-major otherwise,
//   - replaces the op, passing (coopType, bufferPtr, strideValue, layout).
//
// NOTE(review): this listing skips source lines; the base class, the
// definitions of `loc`, `coopType` and `bufferPtr`, and the name of the
// replacement op are not shown.
204 struct WmmaLoadOpToSPIRVLowering final
209 matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp op, OpAdaptor adaptor,
211 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
214 auto retType = cast<gpu::MMAMatrixType>(op.getRes().getType());
215 MemRefType memrefType = op.getSrcMemref().getType();
218 adaptor.getIndices(), loc, rewriter);
// Leading dimension of the op, read as a signed 64-bit integer.
225 int64_t stride = op.getLeadDimension().getSExtValue();
227 auto strideValue = rewriter.
create<spirv::ConstantOp>(
230 bool isColMajor = op.getTranspose().value_or(
false);
// An absent transpose attribute is treated as false, i.e. row-major.
231 auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
232 : spirv::CooperativeMatrixLayoutKHR::RowMajor;
235 op, coopType, bufferPtr, strideValue, layout);
// Conversion pattern whose matchAndRewrite takes a
// gpu::SubgroupMmaStoreMatrixOp.
//
// From the lines visible here, matchAndRewrite:
//   - fetches the SPIRVTypeConverter and the destination memref type,
//   - passes the adapted indices, `loc` and the rewriter to a call whose
//     callee is not shown in this listing,
//   - reads the op's leading dimension as a signed 64-bit stride and starts
//     creating a spirv::ConstantOp for it,
//   - selects a column-major layout when the op's transpose attribute is
//     present and true, row-major otherwise,
//   - replaces the op, passing (bufferPtr, adapted source, strideValue,
//     layout).
//
// NOTE(review): this listing skips source lines; the base class, the
// definitions of `loc` and `bufferPtr`, and the name of the replacement op
// are not shown.
242 struct WmmaStoreOpToSPIRVLowering final
247 matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp op, OpAdaptor adaptor,
249 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
252 auto memrefType = cast<MemRefType>(op.getDstMemref().getType());
255 adaptor.getIndices(), loc, rewriter);
// Leading dimension of the op, read as a signed 64-bit integer.
257 int64_t stride = op.getLeadDimension().getSExtValue();
259 auto strideValue = rewriter.
create<spirv::ConstantOp>(
262 bool isColMajor = op.getTranspose().value_or(
false);
// An absent transpose attribute is treated as false, i.e. row-major.
263 auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
264 : spirv::CooperativeMatrixLayoutKHR::RowMajor;
267 op, bufferPtr, adaptor.getSrc(), strideValue, layout);
// Conversion pattern whose matchAndRewrite takes a gpu::SubgroupMmaComputeOp.
// The one visible body line passes the op together with the adapted A and B
// operands to a call whose callee is not shown in this listing.
// NOTE(review): most of this definition is elided here (source lines 275-278,
// 280-282 and everything after 283 are missing).
274 struct WmmaMmaOpToSPIRVLowering final
279 matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
283 subgroupMmaComputeOp, adaptor.getOpA(), adaptor.getOpB(),
295 using namespace mlir;
297 patterns.add<khr::WmmaLoadOpToSPIRVLowering, khr::WmmaMmaOpToSPIRVLowering,
298 khr::WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
299 WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
301 patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
312 .Case(
"AOp", spirv::CooperativeMatrixUseKHR::MatrixA)
313 .Case(
"BOp", spirv::CooperativeMatrixUseKHR::MatrixB)
314 .Default(spirv::CooperativeMatrixUseKHR::MatrixAcc);
318 spirv::Scope::Subgroup, use);
This class implements a pattern rewriter for use with ConversionPatterns.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Type conversion from builtin types to SPIR-V types for shader interface.
void addConversion(FnT &&callback)
Register a conversion function.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
ArrayRef< int64_t > getShape() const
Get shape of the matrix.
Type getElementType() const
Get elementType of a single element.
StringRef getOperand() const
The general form of operation this type supports is given by the equation C += A*B.
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr, using the layout map of baseType.
Include the generated interface declarations.
static bool createElementwiseOp(ConversionPatternRewriter &builder, gpu::SubgroupMmaElementwiseOp op, Type coopType, ValueRange operands)
Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op when the elementwise op dire...
void populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Collect a set of patterns to convert WMMA ops from the GPU dialect to SPIR-V, using the KHR Cooperative Matrix extension.
const FrozenRewritePatternSet & patterns
bool allOperandsHaveSameCoopMatrixType(ValueRange operands)
void populateMMAToSPIRVCoopMatrixTypeConversion(SPIRVTypeConverter &typeConverter)
Registers with the type converter the conversion from MMAMatrixType to the SPIR-V KHR cooperative matrix type.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects the context only if needed; this helps unify some of the attribute construction code.