25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/StringSwitch.h"
41 gpu::SubgroupMmaElementwiseOp op,
Type coopType,
43 assert((isa<spirv::CooperativeMatrixType>(coopType)));
45 switch (op.getOpType()) {
46 case gpu::MMAElementwiseOp::ADDF:
52 case gpu::MMAElementwiseOp::SUBF:
58 case gpu::MMAElementwiseOp::DIVF:
61 case gpu::MMAElementwiseOp::DIVS:
64 case gpu::MMAElementwiseOp::DIVU:
67 case gpu::MMAElementwiseOp::NEGATEF:
70 case gpu::MMAElementwiseOp::NEGATES:
73 case gpu::MMAElementwiseOp::EXTF:
83 assert(!operands.empty());
85 llvm::map_range(operands, [](
Value v) {
return v.
getType(); })))
88 return isa<spirv::CooperativeMatrixType>(operands.front().
getType());
94 struct WmmaConstantOpToSPIRVLowering final
95 : OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
99 matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp op, OpAdaptor adaptor,
100 ConversionPatternRewriter &rewriter)
const override {
101 Value cst = llvm::getSingleElement(adaptor.getOperands());
102 auto coopType = getTypeConverter()->convertType(op.getType());
104 return rewriter.notifyMatchFailure(op,
"type conversion failed");
106 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, coopType, cst);
113 struct WmmaExtractOpToSPIRVLowering final
114 : OpConversionPattern<gpu::SubgroupMmaExtractThreadLocalOp> {
118 matchAndRewrite(gpu::SubgroupMmaExtractThreadLocalOp op, OpAdaptor adaptor,
119 ConversionPatternRewriter &rewriter)
const override {
120 Value matrix = adaptor.getMatrix();
122 getTypeConverter()->convertType<spirv::CooperativeMatrixType>(
125 return rewriter.notifyMatchFailure(op,
"type conversion failed");
127 SmallVector<int32_t> intValues;
128 for (Value val : op.getIndices()) {
129 if (
auto constOp = val.getDefiningOp<arith::ConstantIndexOp>()) {
130 intValues.push_back(
static_cast<int32_t
>(constOp.value()));
132 return rewriter.notifyMatchFailure(op,
"indices must be constants");
136 Type elementType = coopType.getElementType();
137 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
138 op, elementType, matrix, rewriter.getI32ArrayAttr(intValues));
145 struct WmmaInsertOpToSPIRVLowering final
146 : OpConversionPattern<gpu::SubgroupMmaInsertThreadLocalOp> {
150 matchAndRewrite(gpu::SubgroupMmaInsertThreadLocalOp op, OpAdaptor adaptor,
151 ConversionPatternRewriter &rewriter)
const override {
152 Value value = adaptor.getValue();
153 Value matrix = adaptor.getMatrix();
154 auto coopType = getTypeConverter()->convertType(matrix.getType());
156 return rewriter.notifyMatchFailure(op,
"type conversion failed");
158 SmallVector<int32_t> intValues;
159 for (Value val : op.getIndices()) {
160 if (
auto constOp = val.getDefiningOp<arith::ConstantIndexOp>()) {
161 intValues.push_back(
static_cast<int32_t
>(constOp.value()));
163 return rewriter.notifyMatchFailure(op,
"indices must be constants");
167 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
168 op, coopType, value, matrix, rewriter.getI32ArrayAttr(intValues));
175 struct WmmaElementwiseOpToSPIRVDefaultLowering final
176 : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
180 matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
181 ConversionPatternRewriter &rewriter)
const override {
184 return rewriter.notifyMatchFailure(op,
185 "not all operands are coop matrices");
188 auto coopType = getTypeConverter()->convertType(op.getType());
190 return rewriter.notifyMatchFailure(op,
"type conversion failed");
199 struct WmmaElementwiseOpToSPIRVScalarMulLowering final
200 : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
204 matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
205 ConversionPatternRewriter &rewriter)
const override {
206 if (adaptor.getOperands().size() != 2)
211 return rewriter.notifyMatchFailure(op,
212 "not all operands are coop matrices");
215 if (op.getOpType() != gpu::MMAElementwiseOp::MULF)
220 Value lhs = op.getOperands().front();
221 Value rhs = op.getOperands().back();
222 Value splat =
nullptr;
223 Value matrix =
nullptr;
224 if (lhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
225 splat = adaptor.getOperands().front();
226 matrix = adaptor.getOperands().back();
227 }
else if (rhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
228 matrix = adaptor.getOperands().front();
229 splat = adaptor.getOperands().back();
231 if (!splat || !matrix)
232 return rewriter.notifyMatchFailure(op,
"no splat operand");
236 auto cc = splat.getDefiningOp<spirv::CompositeConstructOp>();
238 return rewriter.notifyMatchFailure(op,
239 "splat is not a composite construct");
242 scalar = llvm::getSingleElement(cc.getConstituents());
244 auto coopType = getTypeConverter()->convertType(op.getType());
246 return rewriter.notifyMatchFailure(op,
"type conversion failed");
247 rewriter.replaceOpWithNewOp<spirv::MatrixTimesScalarOp>(
248 op, coopType, ValueRange{matrix, scalar});
263 struct WmmaLoadOpToSPIRVLowering final
268 matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp op, OpAdaptor adaptor,
270 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
273 auto retType = cast<gpu::MMAMatrixType>(op.getRes().getType());
274 MemRefType memrefType = op.getSrcMemref().getType();
277 adaptor.getIndices(), loc, rewriter);
284 int64_t stride = op.getLeadDimension().getSExtValue();
286 auto strideValue = rewriter.
create<spirv::ConstantOp>(
289 bool isColMajor = op.getTranspose().value_or(
false);
290 auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
291 : spirv::CooperativeMatrixLayoutKHR::RowMajor;
294 op, coopType, bufferPtr, strideValue, layout);
301 struct WmmaStoreOpToSPIRVLowering final
306 matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp op, OpAdaptor adaptor,
308 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
311 auto memrefType = cast<MemRefType>(op.getDstMemref().getType());
314 adaptor.getIndices(), loc, rewriter);
316 int64_t stride = op.getLeadDimension().getSExtValue();
318 auto strideValue = rewriter.
create<spirv::ConstantOp>(
321 bool isColMajor = op.getTranspose().value_or(
false);
322 auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
323 : spirv::CooperativeMatrixLayoutKHR::RowMajor;
326 op, bufferPtr, adaptor.getSrc(), strideValue, layout);
333 struct WmmaMmaOpToSPIRVLowering final
338 matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
342 subgroupMmaComputeOp, adaptor.getOpA(), adaptor.getOpB(),
354 using namespace mlir;
356 patterns.add<khr::WmmaLoadOpToSPIRVLowering, khr::WmmaMmaOpToSPIRVLowering,
357 khr::WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
358 WmmaExtractOpToSPIRVLowering, WmmaInsertOpToSPIRVLowering,
359 WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
361 patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
372 .Case(
"AOp", spirv::CooperativeMatrixUseKHR::MatrixA)
373 .Case(
"BOp", spirv::CooperativeMatrixUseKHR::MatrixB)
374 .Default(spirv::CooperativeMatrixUseKHR::MatrixAcc);
378 spirv::Scope::Subgroup, use);
This class implements a pattern rewriter for use with ConversionPatterns.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Type conversion from builtin types to SPIR-V types for shader interface.
void addConversion(FnT &&callback)
Register a conversion function.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
ArrayRef< int64_t > getShape() const
Get shape of the matrix.
Type getElementType() const
Get elementType of a single element.
StringRef getOperand() const
The general form of operation this type supports is given by the equation C += A*B.
@ Type
An inlay hint that is for a type annotation.
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr, using the layout map of baseType.
Include the generated interface declarations.
static bool createElementwiseOp(ConversionPatternRewriter &builder, gpu::SubgroupMmaElementwiseOp op, Type coopType, ValueRange operands)
Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op when the elementwise op dire...
void populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV, using the KHR Cooperative Matrix extension.
const FrozenRewritePatternSet & patterns
bool allOperandsHaveSameCoopMatrixType(ValueRange operands)
void populateMMAToSPIRVCoopMatrixTypeConversion(SPIRVTypeConverter &typeConverter)
Adds MMAMatrixType conversions to SPIR-V cooperative matrix KHR type conversion to the type converter.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.