28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/ADT/StringSwitch.h"
44 gpu::SubgroupMmaElementwiseOp op,
Type coopType,
46 assert((isa<spirv::CooperativeMatrixType>(coopType)));
48 switch (op.getOpType()) {
49 case gpu::MMAElementwiseOp::ADDF:
55 case gpu::MMAElementwiseOp::SUBF:
61 case gpu::MMAElementwiseOp::DIVF:
64 case gpu::MMAElementwiseOp::DIVS:
67 case gpu::MMAElementwiseOp::DIVU:
70 case gpu::MMAElementwiseOp::NEGATEF:
73 case gpu::MMAElementwiseOp::NEGATES:
76 case gpu::MMAElementwiseOp::EXTF:
86 assert(!operands.empty());
88 llvm::map_range(operands, [](
Value v) {
return v.
getType(); })))
91 return isa<spirv::CooperativeMatrixType>(operands.front().
getType());
97 struct WmmaConstantOpToSPIRVLowering final
98 : OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
102 matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp op, OpAdaptor adaptor,
103 ConversionPatternRewriter &rewriter)
const override {
104 assert(adaptor.getOperands().size() == 1);
105 Value cst = adaptor.getOperands().front();
106 auto coopType = getTypeConverter()->convertType(op.getType());
108 return rewriter.notifyMatchFailure(op,
"type conversion failed");
110 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, coopType, cst);
117 struct WmmaElementwiseOpToSPIRVDefaultLowering final
118 : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
122 matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
123 ConversionPatternRewriter &rewriter)
const override {
126 return rewriter.notifyMatchFailure(op,
127 "not all operands are coop matrices");
130 auto coopType = getTypeConverter()->convertType(op.getType());
132 return rewriter.notifyMatchFailure(op,
"type conversion failed");
141 struct WmmaElementwiseOpToSPIRVScalarMulLowering final
142 : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
146 matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
147 ConversionPatternRewriter &rewriter)
const override {
148 if (adaptor.getOperands().size() != 2)
153 return rewriter.notifyMatchFailure(op,
154 "not all operands are coop matrices");
157 if (op.getOpType() != gpu::MMAElementwiseOp::MULF)
162 Value lhs = op.getOperands().front();
163 Value rhs = op.getOperands().back();
164 Value splat =
nullptr;
165 Value matrix =
nullptr;
166 if (lhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
167 splat = adaptor.getOperands().front();
168 matrix = adaptor.getOperands().back();
169 }
else if (rhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
170 matrix = adaptor.getOperands().front();
171 splat = adaptor.getOperands().back();
173 if (!splat || !matrix)
174 return rewriter.notifyMatchFailure(op,
"no splat operand");
178 auto cc = splat.getDefiningOp<spirv::CompositeConstructOp>();
180 return rewriter.notifyMatchFailure(op,
181 "splat is not a composite construct");
184 assert(cc.getConstituents().size() == 1);
185 scalar = cc.getConstituents().front();
187 auto coopType = getTypeConverter()->convertType(op.getType());
189 return rewriter.notifyMatchFailure(op,
"type conversion failed");
190 rewriter.replaceOpWithNewOp<spirv::MatrixTimesScalarOp>(
191 op, coopType, ValueRange{matrix, scalar});
206 struct WmmaLoadOpToSPIRVLowering final
211 matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp op, OpAdaptor adaptor,
213 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
216 auto retType = cast<gpu::MMAMatrixType>(op.getRes().getType());
217 MemRefType memrefType = op.getSrcMemref().getType();
220 adaptor.getIndices(), loc, rewriter);
227 int64_t stride = op.getLeadDimension().getSExtValue();
229 auto strideValue = rewriter.
create<spirv::ConstantOp>(
232 bool isColMajor = op.getTranspose().value_or(
false);
233 auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
234 : spirv::CooperativeMatrixLayoutKHR::RowMajor;
237 op, coopType, bufferPtr, strideValue, layout);
244 struct WmmaStoreOpToSPIRVLowering final
249 matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp op, OpAdaptor adaptor,
251 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
254 auto memrefType = cast<MemRefType>(op.getDstMemref().getType());
257 adaptor.getIndices(), loc, rewriter);
259 int64_t stride = op.getLeadDimension().getSExtValue();
261 auto strideValue = rewriter.
create<spirv::ConstantOp>(
264 bool isColMajor = op.getTranspose().value_or(
false);
265 auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
266 : spirv::CooperativeMatrixLayoutKHR::RowMajor;
269 op, bufferPtr, adaptor.getSrc(), strideValue, layout);
276 struct WmmaMmaOpToSPIRVLowering final
281 matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
285 subgroupMmaComputeOp, adaptor.getOpA(), adaptor.getOpB(),
297 using namespace mlir;
299 patterns.add<khr::WmmaLoadOpToSPIRVLowering, khr::WmmaMmaOpToSPIRVLowering,
300 khr::WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
301 WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
303 patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
314 .Case(
"AOp", spirv::CooperativeMatrixUseKHR::MatrixA)
315 .Case(
"BOp", spirv::CooperativeMatrixUseKHR::MatrixB)
316 .Default(spirv::CooperativeMatrixUseKHR::MatrixAcc);
320 spirv::Scope::Subgroup, use);
This class implements a pattern rewriter for use with ConversionPatterns.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Type conversion from builtin types to SPIR-V types for shader interface.
void addConversion(FnT &&callback)
Register a conversion function.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
ArrayRef< int64_t > getShape() const
Get shape of the matrix.
Type getElementType() const
Get elementType of a single element.
StringRef getOperand() const
The general form of operation this type supports is given by the equation C += A*B.
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr, using the layout map of baseType.
Include the generated interface declarations.
static bool createElementwiseOp(ConversionPatternRewriter &builder, gpu::SubgroupMmaElementwiseOp op, Type coopType, ValueRange operands)
Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op when the elementwise op directly supports cooperative matrix types.
void populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV, using the KHR Cooperative Matrix extension.
const FrozenRewritePatternSet & patterns
bool allOperandsHaveSameCoopMatrixType(ValueRange operands)
void populateMMAToSPIRVCoopMatrixTypeConversion(SPIRVTypeConverter &typeConverter)
Adds conversions from MMAMatrixType to the SPIR-V cooperative matrix KHR type to the type converter.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.