28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/ADT/StringSwitch.h"
44 gpu::SubgroupMmaElementwiseOp op,
Type coopType,
46 assert((isa<spirv::CooperativeMatrixType, spirv::CooperativeMatrixNVType>(
49 switch (op.getOpType()) {
50 case gpu::MMAElementwiseOp::ADDF:
56 case gpu::MMAElementwiseOp::SUBF:
62 case gpu::MMAElementwiseOp::DIVF:
65 case gpu::MMAElementwiseOp::DIVS:
68 case gpu::MMAElementwiseOp::DIVU:
71 case gpu::MMAElementwiseOp::NEGATEF:
74 case gpu::MMAElementwiseOp::NEGATES:
77 case gpu::MMAElementwiseOp::EXTF:
87 assert(!operands.empty());
89 llvm::map_range(operands, [](
Value v) {
return v.
getType(); })))
92 return isa<spirv::CooperativeMatrixType, spirv::CooperativeMatrixNVType>(
99 struct WmmaConstantOpToSPIRVLowering final
100 : OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
104 matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp op, OpAdaptor adaptor,
105 ConversionPatternRewriter &rewriter)
const override {
106 assert(adaptor.getOperands().size() == 1);
107 Value cst = adaptor.getOperands().front();
108 auto coopType = getTypeConverter()->convertType(op.getType());
110 return rewriter.notifyMatchFailure(op,
"type conversion failed");
112 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, coopType, cst);
119 struct WmmaElementwiseOpToSPIRVDefaultLowering final
120 : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
124 matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
125 ConversionPatternRewriter &rewriter)
const override {
128 return rewriter.notifyMatchFailure(op,
129 "not all operands are coop matrices");
132 auto coopType = getTypeConverter()->convertType(op.getType());
134 return rewriter.notifyMatchFailure(op,
"type conversion failed");
143 struct WmmaElementwiseOpToSPIRVScalarMulLowering final
144 : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
148 matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
149 ConversionPatternRewriter &rewriter)
const override {
150 if (adaptor.getOperands().size() != 2)
155 return rewriter.notifyMatchFailure(op,
156 "not all operands are coop matrices");
159 if (op.getOpType() != gpu::MMAElementwiseOp::MULF)
164 Value lhs = op.getOperands().front();
165 Value rhs = op.getOperands().back();
166 Value splat =
nullptr;
167 Value matrix =
nullptr;
168 if (lhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
169 splat = adaptor.getOperands().front();
170 matrix = adaptor.getOperands().back();
171 }
else if (rhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
172 matrix = adaptor.getOperands().front();
173 splat = adaptor.getOperands().back();
175 if (!splat || !matrix)
176 return rewriter.notifyMatchFailure(op,
"no splat operand");
180 auto cc = splat.getDefiningOp<spirv::CompositeConstructOp>();
182 return rewriter.notifyMatchFailure(op,
183 "splat is not a composite construct");
186 assert(cc.getConstituents().size() == 1);
187 scalar = cc.getConstituents().front();
189 auto coopType = getTypeConverter()->convertType(op.getType());
191 return rewriter.notifyMatchFailure(op,
"type conversion failed");
192 rewriter.replaceOpWithNewOp<spirv::MatrixTimesScalarOp>(
193 op, coopType, ValueRange{matrix, scalar});
208 struct WmmaLoadOpToSPIRVLowering final
213 matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp op, OpAdaptor adaptor,
215 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
218 auto retType = cast<gpu::MMAMatrixType>(op.getRes().getType());
219 MemRefType memrefType = op.getSrcMemref().getType();
222 adaptor.getIndices(), loc, rewriter);
229 int64_t stride = op.getLeadDimension().getSExtValue();
231 auto strideValue = rewriter.
create<spirv::ConstantOp>(
234 bool isColMajor = op.getTranspose().value_or(
false);
235 auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
236 : spirv::CooperativeMatrixLayoutKHR::RowMajor;
239 op, coopType, bufferPtr, strideValue, layout);
246 struct WmmaStoreOpToSPIRVLowering final
251 matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp op, OpAdaptor adaptor,
253 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
256 auto memrefType = cast<MemRefType>(op.getDstMemref().getType());
259 adaptor.getIndices(), loc, rewriter);
261 int64_t stride = op.getLeadDimension().getSExtValue();
263 auto strideValue = rewriter.
create<spirv::ConstantOp>(
266 bool isColMajor = op.getTranspose().value_or(
false);
267 auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
268 : spirv::CooperativeMatrixLayoutKHR::RowMajor;
271 op, bufferPtr, adaptor.getSrc(), strideValue, layout);
278 struct WmmaMmaOpToSPIRVLowering final
283 matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
287 subgroupMmaComputeOp, adaptor.getOpA(), adaptor.getOpB(),
305 struct WmmaLoadOpToSPIRVLowering final
310 matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp subgroupMmaLoadMatrixOp,
313 Location loc = subgroupMmaLoadMatrixOp->getLoc();
314 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
317 cast<gpu::MMAMatrixType>(subgroupMmaLoadMatrixOp.getRes().getType());
319 cast<MemRefType>(subgroupMmaLoadMatrixOp.getSrcMemref().getType());
322 adaptor.getIndices(), loc, rewriter);
327 "type conversion failed");
329 int64_t stride = subgroupMmaLoadMatrixOp.getLeadDimension().getSExtValue();
331 auto strideValue = rewriter.
create<spirv::ConstantOp>(
333 bool isColMajor =
static_cast<bool>(subgroupMmaLoadMatrixOp.getTranspose());
334 auto columnMajor = rewriter.
create<spirv::ConstantOp>(
337 subgroupMmaLoadMatrixOp, coopType, bufferPtr, strideValue, columnMajor,
338 spirv::MemoryAccessAttr());
345 struct WmmaStoreOpToSPIRVLowering final
350 matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp subgroupMmaStoreMatrixOp,
353 Location loc = subgroupMmaStoreMatrixOp->getLoc();
355 cast<MemRefType>(subgroupMmaStoreMatrixOp.getDstMemref().getType());
357 *getTypeConverter<const SPIRVTypeConverter>(), memrefType,
358 adaptor.getDstMemref(), adaptor.getIndices(), loc, rewriter);
359 int64_t stride = subgroupMmaStoreMatrixOp.getLeadDimension().getSExtValue();
361 auto strideValue = rewriter.
create<spirv::ConstantOp>(
364 static_cast<bool>(subgroupMmaStoreMatrixOp.getTranspose());
365 auto columnMajor = rewriter.
create<spirv::ConstantOp>(
368 subgroupMmaStoreMatrixOp, bufferPtr, adaptor.getSrc(), strideValue,
369 columnMajor, spirv::MemoryAccessAttr());
376 struct WmmaMmaOpToSPIRVLowering final
381 matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
385 subgroupMmaComputeOp, adaptor.getOpC().getType(), adaptor.getOpA(),
386 adaptor.getOpB(), adaptor.getOpC());
397 using namespace mlir;
399 patterns.
add<khr::WmmaLoadOpToSPIRVLowering, khr::WmmaMmaOpToSPIRVLowering,
400 khr::WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
401 WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
403 patterns.
add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
409 using namespace mlir;
411 patterns.
add<nv::WmmaLoadOpToSPIRVLowering, nv::WmmaMmaOpToSPIRVLowering,
412 nv::WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
413 WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
415 patterns.
add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
426 elementType, spirv::Scope::Subgroup, retTypeShape[0],
437 .Case(
"AOp", spirv::CooperativeMatrixUseKHR::MatrixA)
438 .Case(
"BOp", spirv::CooperativeMatrixUseKHR::MatrixB)
439 .Default(spirv::CooperativeMatrixUseKHR::MatrixAcc);
443 spirv::Scope::Subgroup, use);
BoolAttr getBoolAttr(bool value)
This class implements a pattern rewriter for use with ConversionPatterns.
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Type conversion from builtin types to SPIR-V types for shader interface.
void addConversion(FnT &&callback)
Register a conversion function.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
ArrayRef< int64_t > getShape() const
Get shape of the matrix.
Type getElementType() const
Get elementType of a single element.
StringRef getOperand() const
The general form of operation this type supports is given by the equation C += A*B.
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr, using the layout map of baseType.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
void populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV, using the KHR Cooperative Matrix extension.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
void populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV, using the NV Cooperative Matrix extension.
static bool createElementwiseOp(ConversionPatternRewriter &builder, gpu::SubgroupMmaElementwiseOp op, Type coopType, ValueRange operands)
Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op when the elementwise op directly supports cooperative matrix types.
bool allOperandsHaveSameCoopMatrixType(ValueRange operands)
void populateMMAToSPIRVCoopMatrixTypeConversion(SPIRVTypeConverter &typeConverter, bool useNVTypes=false)
Adds MMAMatrixType conversions to SPIR-V cooperative matrix type conversion to the type converter.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
This class represents an efficient way to signal success or failure.