28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/ADT/StringSwitch.h"
44 gpu::SubgroupMmaElementwiseOp op,
Type coopType,
46 assert((isa<spirv::CooperativeMatrixType>(coopType)));
48 switch (op.getOpType()) {
49 case gpu::MMAElementwiseOp::ADDF:
55 case gpu::MMAElementwiseOp::SUBF:
61 case gpu::MMAElementwiseOp::DIVF:
64 case gpu::MMAElementwiseOp::DIVS:
67 case gpu::MMAElementwiseOp::DIVU:
70 case gpu::MMAElementwiseOp::NEGATEF:
73 case gpu::MMAElementwiseOp::NEGATES:
76 case gpu::MMAElementwiseOp::EXTF:
86 assert(!operands.empty());
88 llvm::map_range(operands, [](
Value v) {
return v.
getType(); })))
91 return isa<spirv::CooperativeMatrixType>(operands.front().
getType());
97 struct WmmaConstantOpToSPIRVLowering final
98 : OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
102 matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp op, OpAdaptor adaptor,
103 ConversionPatternRewriter &rewriter)
const override {
104 Value cst = llvm::getSingleElement(adaptor.getOperands());
105 auto coopType = getTypeConverter()->convertType(op.getType());
107 return rewriter.notifyMatchFailure(op,
"type conversion failed");
109 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, coopType, cst);
116 struct WmmaExtractOpToSPIRVLowering final
117 : OpConversionPattern<gpu::SubgroupMmaExtractThreadLocalOp> {
121 matchAndRewrite(gpu::SubgroupMmaExtractThreadLocalOp op, OpAdaptor adaptor,
122 ConversionPatternRewriter &rewriter)
const override {
123 Value matrix = adaptor.getMatrix();
125 getTypeConverter()->convertType<spirv::CooperativeMatrixType>(
128 return rewriter.notifyMatchFailure(op,
"type conversion failed");
130 SmallVector<int32_t> intValues;
131 for (Value val : op.getIndices()) {
132 if (
auto constOp = val.getDefiningOp<arith::ConstantIndexOp>()) {
133 intValues.push_back(
static_cast<int32_t
>(constOp.value()));
135 return rewriter.notifyMatchFailure(op,
"indices must be constants");
139 Type elementType = coopType.getElementType();
140 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
141 op, elementType, matrix, rewriter.getI32ArrayAttr(intValues));
148 struct WmmaInsertOpToSPIRVLowering final
149 : OpConversionPattern<gpu::SubgroupMmaInsertThreadLocalOp> {
153 matchAndRewrite(gpu::SubgroupMmaInsertThreadLocalOp op, OpAdaptor adaptor,
154 ConversionPatternRewriter &rewriter)
const override {
155 Value value = adaptor.getValue();
156 Value matrix = adaptor.getMatrix();
157 auto coopType = getTypeConverter()->convertType(matrix.getType());
159 return rewriter.notifyMatchFailure(op,
"type conversion failed");
161 SmallVector<int32_t> intValues;
162 for (Value val : op.getIndices()) {
163 if (
auto constOp = val.getDefiningOp<arith::ConstantIndexOp>()) {
164 intValues.push_back(
static_cast<int32_t
>(constOp.value()));
166 return rewriter.notifyMatchFailure(op,
"indices must be constants");
170 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
171 op, coopType, value, matrix, rewriter.getI32ArrayAttr(intValues));
178 struct WmmaElementwiseOpToSPIRVDefaultLowering final
179 : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
183 matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
184 ConversionPatternRewriter &rewriter)
const override {
187 return rewriter.notifyMatchFailure(op,
188 "not all operands are coop matrices");
191 auto coopType = getTypeConverter()->convertType(op.getType());
193 return rewriter.notifyMatchFailure(op,
"type conversion failed");
202 struct WmmaElementwiseOpToSPIRVScalarMulLowering final
203 : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
207 matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
208 ConversionPatternRewriter &rewriter)
const override {
209 if (adaptor.getOperands().size() != 2)
214 return rewriter.notifyMatchFailure(op,
215 "not all operands are coop matrices");
218 if (op.getOpType() != gpu::MMAElementwiseOp::MULF)
223 Value lhs = op.getOperands().front();
224 Value rhs = op.getOperands().back();
225 Value splat =
nullptr;
226 Value matrix =
nullptr;
227 if (lhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
228 splat = adaptor.getOperands().front();
229 matrix = adaptor.getOperands().back();
230 }
else if (rhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
231 matrix = adaptor.getOperands().front();
232 splat = adaptor.getOperands().back();
234 if (!splat || !matrix)
235 return rewriter.notifyMatchFailure(op,
"no splat operand");
239 auto cc = splat.getDefiningOp<spirv::CompositeConstructOp>();
241 return rewriter.notifyMatchFailure(op,
242 "splat is not a composite construct");
245 scalar = llvm::getSingleElement(cc.getConstituents());
247 auto coopType = getTypeConverter()->convertType(op.getType());
249 return rewriter.notifyMatchFailure(op,
"type conversion failed");
250 rewriter.replaceOpWithNewOp<spirv::MatrixTimesScalarOp>(
251 op, coopType, ValueRange{matrix, scalar});
266 struct WmmaLoadOpToSPIRVLowering final
271 matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp op, OpAdaptor adaptor,
273 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
276 auto retType = cast<gpu::MMAMatrixType>(op.getRes().getType());
277 MemRefType memrefType = op.getSrcMemref().getType();
280 adaptor.getIndices(), loc, rewriter);
287 int64_t stride = op.getLeadDimension().getSExtValue();
289 auto strideValue = rewriter.
create<spirv::ConstantOp>(
292 bool isColMajor = op.getTranspose().value_or(
false);
293 auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
294 : spirv::CooperativeMatrixLayoutKHR::RowMajor;
297 op, coopType, bufferPtr, strideValue, layout);
304 struct WmmaStoreOpToSPIRVLowering final
309 matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp op, OpAdaptor adaptor,
311 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
314 auto memrefType = cast<MemRefType>(op.getDstMemref().getType());
317 adaptor.getIndices(), loc, rewriter);
319 int64_t stride = op.getLeadDimension().getSExtValue();
321 auto strideValue = rewriter.
create<spirv::ConstantOp>(
324 bool isColMajor = op.getTranspose().value_or(
false);
325 auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
326 : spirv::CooperativeMatrixLayoutKHR::RowMajor;
329 op, bufferPtr, adaptor.getSrc(), strideValue, layout);
336 struct WmmaMmaOpToSPIRVLowering final
341 matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
345 subgroupMmaComputeOp, adaptor.getOpA(), adaptor.getOpB(),
357 using namespace mlir;
359 patterns.add<khr::WmmaLoadOpToSPIRVLowering, khr::WmmaMmaOpToSPIRVLowering,
360 khr::WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
361 WmmaExtractOpToSPIRVLowering, WmmaInsertOpToSPIRVLowering,
362 WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
364 patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
375 .Case(
"AOp", spirv::CooperativeMatrixUseKHR::MatrixA)
376 .Case(
"BOp", spirv::CooperativeMatrixUseKHR::MatrixB)
377 .Default(spirv::CooperativeMatrixUseKHR::MatrixAcc);
381 spirv::Scope::Subgroup, use);
This class implements a pattern rewriter for use with ConversionPatterns.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (replacing the results of the original op with the results of the new op).
Type conversion from builtin types to SPIR-V types for shader interface.
void addConversion(FnT &&callback)
Register a conversion function.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
ArrayRef< int64_t > getShape() const
Get shape of the matrix.
Type getElementType() const
Get elementType of a single element.
StringRef getOperand() const
The general form of operation this type supports is given by the equation C += A*B.
@ Type
An inlay hint that is for a type annotation.
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Include the generated interface declarations.
static bool createElementwiseOp(ConversionPatternRewriter &builder, gpu::SubgroupMmaElementwiseOp op, Type coopType, ValueRange operands)
Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op when the elementwise op dire...
void populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV, using the KHR Cooperative Ma...
const FrozenRewritePatternSet & patterns
bool allOperandsHaveSameCoopMatrixType(ValueRange operands)
void populateMMAToSPIRVCoopMatrixTypeConversion(SPIRVTypeConverter &typeConverter)
Adds MMAMatrixType conversions to SPIR-V cooperative matrix KHR type conversion to the type converter...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...