25#include "llvm/ADT/STLExtras.h"
26#include "llvm/ADT/StringSwitch.h"
41 gpu::SubgroupMmaElementwiseOp op,
Type coopType,
43 assert((isa<spirv::CooperativeMatrixType>(coopType)));
45 switch (op.getOpType()) {
46 case gpu::MMAElementwiseOp::ADDF:
47 builder.replaceOpWithNewOp<spirv::FAddOp>(op, coopType, operands);
49 case gpu::MMAElementwiseOp::ADDI:
50 builder.replaceOpWithNewOp<spirv::IAddOp>(op, coopType, operands);
52 case gpu::MMAElementwiseOp::SUBF:
53 builder.replaceOpWithNewOp<spirv::FSubOp>(op, coopType, operands);
55 case gpu::MMAElementwiseOp::SUBI:
56 builder.replaceOpWithNewOp<spirv::ISubOp>(op, coopType, operands);
58 case gpu::MMAElementwiseOp::MULF:
59 builder.replaceOpWithNewOp<spirv::FMulOp>(op, coopType, operands);
61 case gpu::MMAElementwiseOp::DIVF:
62 builder.replaceOpWithNewOp<spirv::FDivOp>(op, coopType, operands);
64 case gpu::MMAElementwiseOp::DIVS:
65 builder.replaceOpWithNewOp<spirv::SDivOp>(op, coopType, operands);
67 case gpu::MMAElementwiseOp::DIVU:
68 builder.replaceOpWithNewOp<spirv::UDivOp>(op, coopType, operands);
70 case gpu::MMAElementwiseOp::NEGATEF:
71 builder.replaceOpWithNewOp<spirv::FNegateOp>(op, coopType, operands);
73 case gpu::MMAElementwiseOp::NEGATES:
74 builder.replaceOpWithNewOp<spirv::SNegateOp>(op, coopType, operands);
76 case gpu::MMAElementwiseOp::EXTF:
77 case gpu::MMAElementwiseOp::TRUNCF:
78 builder.replaceOpWithNewOp<spirv::FConvertOp>(op, coopType, operands);
87 assert(!operands.empty());
89 llvm::map_range(operands, [](
Value v) {
return v.
getType(); })))
92 return isa<spirv::CooperativeMatrixType>(operands.front().
getType());
97 return elementType && elementType.isSigned();
100static spirv::CooperativeMatrixOperandsKHR
105 using Operands = spirv::CooperativeMatrixOperandsKHR;
107 Operands operands = Operands::None;
109 operands |= Operands::ASigned;
111 operands |= Operands::BSigned;
113 operands |= Operands::CSigned;
115 operands |= Operands::ResultSigned;
122struct WmmaConstantOpToSPIRVLowering final
123 : OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
127 matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp op, OpAdaptor adaptor,
128 ConversionPatternRewriter &rewriter)
const override {
129 Value cst = llvm::getSingleElement(adaptor.getOperands());
130 auto coopType = getTypeConverter()->convertType(op.getType());
132 return rewriter.notifyMatchFailure(op,
"type conversion failed");
134 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, coopType, cst);
141struct WmmaExtractOpToSPIRVLowering final
142 : OpConversionPattern<gpu::SubgroupMmaExtractThreadLocalOp> {
146 matchAndRewrite(gpu::SubgroupMmaExtractThreadLocalOp op, OpAdaptor adaptor,
147 ConversionPatternRewriter &rewriter)
const override {
148 Value matrix = adaptor.getMatrix();
150 getTypeConverter()->convertType<spirv::CooperativeMatrixType>(
153 return rewriter.notifyMatchFailure(op,
"type conversion failed");
155 SmallVector<int32_t> intValues;
156 for (Value val : op.getIndices()) {
157 if (
auto constOp = val.getDefiningOp<arith::ConstantIndexOp>()) {
158 intValues.push_back(
static_cast<int32_t
>(constOp.value()));
160 return rewriter.notifyMatchFailure(op,
"indices must be constants");
164 Type elementType = coopType.getElementType();
165 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
166 op, elementType, matrix, rewriter.getI32ArrayAttr(intValues));
173struct WmmaInsertOpToSPIRVLowering final
174 : OpConversionPattern<gpu::SubgroupMmaInsertThreadLocalOp> {
178 matchAndRewrite(gpu::SubgroupMmaInsertThreadLocalOp op, OpAdaptor adaptor,
179 ConversionPatternRewriter &rewriter)
const override {
180 Value value = adaptor.getValue();
181 Value matrix = adaptor.getMatrix();
182 auto coopType = getTypeConverter()->convertType(matrix.getType());
184 return rewriter.notifyMatchFailure(op,
"type conversion failed");
186 SmallVector<int32_t> intValues;
187 for (Value val : op.getIndices()) {
188 if (
auto constOp = val.getDefiningOp<arith::ConstantIndexOp>()) {
189 intValues.push_back(
static_cast<int32_t
>(constOp.value()));
191 return rewriter.notifyMatchFailure(op,
"indices must be constants");
195 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
196 op, coopType, value, matrix, rewriter.getI32ArrayAttr(intValues));
203struct WmmaElementwiseOpToSPIRVDefaultLowering final
204 : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
208 matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
209 ConversionPatternRewriter &rewriter)
const override {
212 return rewriter.notifyMatchFailure(op,
213 "not all operands are coop matrices");
216 auto coopType = getTypeConverter()->convertType(op.getType());
218 return rewriter.notifyMatchFailure(op,
"type conversion failed");
227struct WmmaElementwiseOpToSPIRVScalarMulLowering final
228 : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
232 matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
233 ConversionPatternRewriter &rewriter)
const override {
234 if (adaptor.getOperands().size() != 2)
239 return rewriter.notifyMatchFailure(op,
240 "not all operands are coop matrices");
243 if (op.getOpType() != gpu::MMAElementwiseOp::MULF)
248 Value
lhs = op.getOperands().front();
249 Value
rhs = op.getOperands().back();
250 Value splat =
nullptr;
251 Value matrix =
nullptr;
252 if (
lhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
253 splat = adaptor.getOperands().front();
254 matrix = adaptor.getOperands().back();
255 }
else if (
rhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
256 matrix = adaptor.getOperands().front();
257 splat = adaptor.getOperands().back();
259 if (!splat || !matrix)
260 return rewriter.notifyMatchFailure(op,
"no splat operand");
264 auto cc = splat.getDefiningOp<spirv::CompositeConstructOp>();
266 return rewriter.notifyMatchFailure(op,
267 "splat is not a composite construct");
270 scalar = llvm::getSingleElement(cc.getConstituents());
272 auto coopType = getTypeConverter()->convertType(op.getType());
274 return rewriter.notifyMatchFailure(op,
"type conversion failed");
275 rewriter.replaceOpWithNewOp<spirv::MatrixTimesScalarOp>(
291struct WmmaLoadOpToSPIRVLowering final
292 : OpConversionPattern<gpu::SubgroupMmaLoadMatrixOp> {
296 matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp op, OpAdaptor adaptor,
297 ConversionPatternRewriter &rewriter)
const override {
298 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
301 auto retType = cast<gpu::MMAMatrixType>(op.getRes().getType());
302 MemRefType memrefType = op.getSrcMemref().getType();
305 adaptor.getIndices(), loc, rewriter);
310 return rewriter.notifyMatchFailure(op,
"type conversion failed");
312 int64_t stride = op.getLeadDimension().getSExtValue();
313 IntegerType i32Type = rewriter.getI32Type();
314 auto strideValue = spirv::ConstantOp::create(
315 rewriter, loc, i32Type, IntegerAttr::get(i32Type, stride));
317 bool isColMajor = op.getTranspose().value_or(
false);
318 auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
319 : spirv::CooperativeMatrixLayoutKHR::RowMajor;
321 rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixLoadOp>(
322 op, coopType, bufferPtr, strideValue, layout);
329struct WmmaStoreOpToSPIRVLowering final
330 : OpConversionPattern<gpu::SubgroupMmaStoreMatrixOp> {
334 matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp op, OpAdaptor adaptor,
335 ConversionPatternRewriter &rewriter)
const override {
336 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
339 auto memrefType = cast<MemRefType>(op.getDstMemref().getType());
342 adaptor.getIndices(), loc, rewriter);
344 int64_t stride = op.getLeadDimension().getSExtValue();
345 IntegerType i32Type = rewriter.getI32Type();
346 auto strideValue = spirv::ConstantOp::create(
347 rewriter, loc, i32Type, IntegerAttr::get(i32Type, stride));
349 bool isColMajor = op.getTranspose().value_or(
false);
350 auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
351 : spirv::CooperativeMatrixLayoutKHR::RowMajor;
353 rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixStoreOp>(
354 op, bufferPtr, adaptor.getSrc(), strideValue, layout);
361struct WmmaMmaOpToSPIRVLowering final
362 : OpConversionPattern<gpu::SubgroupMmaComputeOp> {
366 matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
368 ConversionPatternRewriter &rewriter)
const override {
370 dyn_cast<spirv::CooperativeMatrixType>(adaptor.getOpA().getType());
372 dyn_cast<spirv::CooperativeMatrixType>(adaptor.getOpB().getType());
374 dyn_cast<spirv::CooperativeMatrixType>(adaptor.getOpC().getType());
377 subgroupMmaComputeOp.getResult().
getType());
378 if (!aType || !bType || !cType || !resultType)
379 return rewriter.notifyMatchFailure(subgroupMmaComputeOp,
380 "type conversion failed");
382 using Operands = spirv::CooperativeMatrixOperandsKHR;
385 spirv::CooperativeMatrixOperandsKHRAttr operandsAttr;
386 if (operands != Operands::None)
387 operandsAttr = spirv::CooperativeMatrixOperandsKHRAttr::get(
388 rewriter.getContext(), operands);
390 rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixMulAddOp>(
391 subgroupMmaComputeOp, adaptor.getOpA(), adaptor.getOpB(),
392 adaptor.getOpC(), operandsAttr);
403 using namespace mlir;
405 patterns.
add<khr::WmmaLoadOpToSPIRVLowering, khr::WmmaMmaOpToSPIRVLowering,
406 khr::WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
407 WmmaExtractOpToSPIRVLowering, WmmaInsertOpToSPIRVLowering,
408 WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
410 patterns.
add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
421 .Case(
"AOp", spirv::CooperativeMatrixUseKHR::MatrixA)
422 .Case(
"BOp", spirv::CooperativeMatrixUseKHR::MatrixB)
423 .Default(spirv::CooperativeMatrixUseKHR::MatrixAcc);
427 spirv::Scope::Subgroup, use);
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
ArrayRef< int64_t > getShape() const
Get shape of the matrix.
Type getElementType() const
Get elementType of a single element.
StringRef getOperand() const
The general form of operation this type supports is given by the equation C += A*B. This function returns which operand in that equation is held by this type: "AOp", "BOp", or "COp".
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
Type getElementType() const
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr, using the layout map of baseType.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
static bool createElementwiseOp(ConversionPatternRewriter &builder, gpu::SubgroupMmaElementwiseOp op, Type coopType, ValueRange operands)
Creates a SPIR-V op to replace the given GPU subgroup MMA elementwise op when the elementwise op is directly supported on cooperative matrix types. Returns false if it cannot.
void populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Collects a set of patterns to convert WMMA ops from the GPU dialect to SPIR-V, using the KHR Cooperative Matrix extension.
static spirv::CooperativeMatrixOperandsKHR getSignedCoopMatrixOperands(spirv::CooperativeMatrixType aType, spirv::CooperativeMatrixType bType, spirv::CooperativeMatrixType cType, spirv::CooperativeMatrixType resultType)
bool allOperandsHaveSameCoopMatrixType(ValueRange operands)
void populateMMAToSPIRVCoopMatrixTypeConversion(SPIRVTypeConverter &typeConverter)
Adds conversions from MMAMatrixType to SPIR-V KHR cooperative matrix types to the type converter.
static bool hasSignedIntegerElementType(spirv::CooperativeMatrixType type)