25#include "llvm/ADT/STLExtras.h"
26#include "llvm/ADT/StringSwitch.h"
41 gpu::SubgroupMmaElementwiseOp op,
Type coopType,
43 assert((isa<spirv::CooperativeMatrixType>(coopType)));
45 switch (op.getOpType()) {
46 case gpu::MMAElementwiseOp::ADDF:
47 builder.replaceOpWithNewOp<spirv::FAddOp>(op, coopType, operands);
49 case gpu::MMAElementwiseOp::ADDI:
50 builder.replaceOpWithNewOp<spirv::IAddOp>(op, coopType, operands);
52 case gpu::MMAElementwiseOp::SUBF:
53 builder.replaceOpWithNewOp<spirv::FSubOp>(op, coopType, operands);
55 case gpu::MMAElementwiseOp::SUBI:
56 builder.replaceOpWithNewOp<spirv::ISubOp>(op, coopType, operands);
58 case gpu::MMAElementwiseOp::MULF:
59 builder.replaceOpWithNewOp<spirv::FMulOp>(op, coopType, operands);
61 case gpu::MMAElementwiseOp::DIVF:
62 builder.replaceOpWithNewOp<spirv::FDivOp>(op, coopType, operands);
64 case gpu::MMAElementwiseOp::DIVS:
65 builder.replaceOpWithNewOp<spirv::SDivOp>(op, coopType, operands);
67 case gpu::MMAElementwiseOp::DIVU:
68 builder.replaceOpWithNewOp<spirv::UDivOp>(op, coopType, operands);
70 case gpu::MMAElementwiseOp::NEGATEF:
71 builder.replaceOpWithNewOp<spirv::FNegateOp>(op, coopType, operands);
73 case gpu::MMAElementwiseOp::NEGATES:
74 builder.replaceOpWithNewOp<spirv::SNegateOp>(op, coopType, operands);
76 case gpu::MMAElementwiseOp::EXTF:
77 case gpu::MMAElementwiseOp::TRUNCF:
78 builder.replaceOpWithNewOp<spirv::FConvertOp>(op, coopType, operands);
87 assert(!operands.empty());
89 llvm::map_range(operands, [](
Value v) {
return v.
getType(); })))
92 return isa<spirv::CooperativeMatrixType>(operands.front().
getType());
98struct WmmaConstantOpToSPIRVLowering final
99 : OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
103 matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp op, OpAdaptor adaptor,
104 ConversionPatternRewriter &rewriter)
const override {
105 Value cst = llvm::getSingleElement(adaptor.getOperands());
106 auto coopType = getTypeConverter()->convertType(op.getType());
108 return rewriter.notifyMatchFailure(op,
"type conversion failed");
110 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, coopType, cst);
117struct WmmaExtractOpToSPIRVLowering final
118 : OpConversionPattern<gpu::SubgroupMmaExtractThreadLocalOp> {
122 matchAndRewrite(gpu::SubgroupMmaExtractThreadLocalOp op, OpAdaptor adaptor,
123 ConversionPatternRewriter &rewriter)
const override {
124 Value matrix = adaptor.getMatrix();
126 getTypeConverter()->convertType<spirv::CooperativeMatrixType>(
129 return rewriter.notifyMatchFailure(op,
"type conversion failed");
131 SmallVector<int32_t> intValues;
132 for (Value val : op.getIndices()) {
133 if (
auto constOp = val.getDefiningOp<arith::ConstantIndexOp>()) {
134 intValues.push_back(
static_cast<int32_t
>(constOp.value()));
136 return rewriter.notifyMatchFailure(op,
"indices must be constants");
140 Type elementType = coopType.getElementType();
141 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
142 op, elementType, matrix, rewriter.getI32ArrayAttr(intValues));
149struct WmmaInsertOpToSPIRVLowering final
150 : OpConversionPattern<gpu::SubgroupMmaInsertThreadLocalOp> {
154 matchAndRewrite(gpu::SubgroupMmaInsertThreadLocalOp op, OpAdaptor adaptor,
155 ConversionPatternRewriter &rewriter)
const override {
156 Value value = adaptor.getValue();
157 Value matrix = adaptor.getMatrix();
158 auto coopType = getTypeConverter()->convertType(matrix.getType());
160 return rewriter.notifyMatchFailure(op,
"type conversion failed");
162 SmallVector<int32_t> intValues;
163 for (Value val : op.getIndices()) {
164 if (
auto constOp = val.getDefiningOp<arith::ConstantIndexOp>()) {
165 intValues.push_back(
static_cast<int32_t
>(constOp.value()));
167 return rewriter.notifyMatchFailure(op,
"indices must be constants");
171 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
172 op, coopType, value, matrix, rewriter.getI32ArrayAttr(intValues));
179struct WmmaElementwiseOpToSPIRVDefaultLowering final
180 : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
184 matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
185 ConversionPatternRewriter &rewriter)
const override {
188 return rewriter.notifyMatchFailure(op,
189 "not all operands are coop matrices");
192 auto coopType = getTypeConverter()->convertType(op.getType());
194 return rewriter.notifyMatchFailure(op,
"type conversion failed");
203struct WmmaElementwiseOpToSPIRVScalarMulLowering final
204 : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
208 matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
209 ConversionPatternRewriter &rewriter)
const override {
210 if (adaptor.getOperands().size() != 2)
215 return rewriter.notifyMatchFailure(op,
216 "not all operands are coop matrices");
219 if (op.getOpType() != gpu::MMAElementwiseOp::MULF)
224 Value
lhs = op.getOperands().front();
225 Value
rhs = op.getOperands().back();
226 Value splat =
nullptr;
227 Value matrix =
nullptr;
228 if (
lhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
229 splat = adaptor.getOperands().front();
230 matrix = adaptor.getOperands().back();
231 }
else if (
rhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
232 matrix = adaptor.getOperands().front();
233 splat = adaptor.getOperands().back();
235 if (!splat || !matrix)
236 return rewriter.notifyMatchFailure(op,
"no splat operand");
240 auto cc = splat.getDefiningOp<spirv::CompositeConstructOp>();
242 return rewriter.notifyMatchFailure(op,
243 "splat is not a composite construct");
246 scalar = llvm::getSingleElement(cc.getConstituents());
248 auto coopType = getTypeConverter()->convertType(op.getType());
250 return rewriter.notifyMatchFailure(op,
"type conversion failed");
251 rewriter.replaceOpWithNewOp<spirv::MatrixTimesScalarOp>(
267struct WmmaLoadOpToSPIRVLowering final
268 : OpConversionPattern<gpu::SubgroupMmaLoadMatrixOp> {
272 matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp op, OpAdaptor adaptor,
273 ConversionPatternRewriter &rewriter)
const override {
274 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
277 auto retType = cast<gpu::MMAMatrixType>(op.getRes().getType());
278 MemRefType memrefType = op.getSrcMemref().getType();
281 adaptor.getIndices(), loc, rewriter);
286 return rewriter.notifyMatchFailure(op,
"type conversion failed");
288 int64_t stride = op.getLeadDimension().getSExtValue();
289 IntegerType i32Type = rewriter.getI32Type();
290 auto strideValue = spirv::ConstantOp::create(
291 rewriter, loc, i32Type, IntegerAttr::get(i32Type, stride));
293 bool isColMajor = op.getTranspose().value_or(
false);
294 auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
295 : spirv::CooperativeMatrixLayoutKHR::RowMajor;
297 rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixLoadOp>(
298 op, coopType, bufferPtr, strideValue, layout);
305struct WmmaStoreOpToSPIRVLowering final
306 : OpConversionPattern<gpu::SubgroupMmaStoreMatrixOp> {
310 matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp op, OpAdaptor adaptor,
311 ConversionPatternRewriter &rewriter)
const override {
312 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
315 auto memrefType = cast<MemRefType>(op.getDstMemref().getType());
318 adaptor.getIndices(), loc, rewriter);
320 int64_t stride = op.getLeadDimension().getSExtValue();
321 IntegerType i32Type = rewriter.getI32Type();
322 auto strideValue = spirv::ConstantOp::create(
323 rewriter, loc, i32Type, IntegerAttr::get(i32Type, stride));
325 bool isColMajor = op.getTranspose().value_or(
false);
326 auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
327 : spirv::CooperativeMatrixLayoutKHR::RowMajor;
329 rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixStoreOp>(
330 op, bufferPtr, adaptor.getSrc(), strideValue, layout);
337struct WmmaMmaOpToSPIRVLowering final
338 : OpConversionPattern<gpu::SubgroupMmaComputeOp> {
342 matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
344 ConversionPatternRewriter &rewriter)
const override {
345 rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixMulAddOp>(
346 subgroupMmaComputeOp, adaptor.getOpA(), adaptor.getOpB(),
358 using namespace mlir;
360 patterns.
add<khr::WmmaLoadOpToSPIRVLowering, khr::WmmaMmaOpToSPIRVLowering,
361 khr::WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
362 WmmaExtractOpToSPIRVLowering, WmmaInsertOpToSPIRVLowering,
363 WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
365 patterns.
add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
376 .Case(
"AOp", spirv::CooperativeMatrixUseKHR::MatrixA)
377 .Case(
"BOp", spirv::CooperativeMatrixUseKHR::MatrixB)
378 .Default(spirv::CooperativeMatrixUseKHR::MatrixAcc);
382 spirv::Scope::Subgroup, use);
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
ArrayRef< int64_t > getShape() const
Get shape of the matrix.
Type getElementType() const
Get elementType of a single element.
StringRef getOperand() const
The general form of operation this type supports is given by the equation C += A*B.
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Include the generated interface declarations.
static bool createElementwiseOp(ConversionPatternRewriter &builder, gpu::SubgroupMmaElementwiseOp op, Type coopType, ValueRange operands)
Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op when the elementwise op dire...
void populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV, using the KHR Cooperative Ma...
bool allOperandsHaveSameCoopMatrixType(ValueRange operands)
void populateMMAToSPIRVCoopMatrixTypeConversion(SPIRVTypeConverter &typeConverter)
Adds MMAMatrixType conversions to SPIR-V cooperative matrix KHR type conversion to the type converter...