31 if (!llvm::all_of(operands, [](
Value value) {
35 op,
"cannot convert if operands aren't of LLVM type.");
42 static constexpr StringRef kInvalidCaseStr =
"Unsupported WMMA variant.";
44 static NVVM::MMAFrag convertOperand(StringRef operandName) {
45 if (operandName ==
"AOp")
46 return NVVM::MMAFrag::a;
47 if (operandName ==
"BOp")
48 return NVVM::MMAFrag::b;
49 if (operandName ==
"COp")
50 return NVVM::MMAFrag::c;
51 llvm_unreachable(
"Unknown operand name");
// Fragment of a helper that maps a gpu::MMAMatrixType's element type to an
// NVVM::MMATypes value. NOTE(review): the guarding `if` conditions for the
// returns below are elided in this extract (leading numbers are original
// file line numbers); pairings other than the visible ternary are inferred
// only from the returned values — confirm against the full source.
56 return NVVM::MMATypes::f16;
// Visible here: for this case the accumulator operand ("COp") maps to f32
// while the other operands map to tf32.
58 return type.
getOperand() ==
"COp" ? NVVM::MMATypes::f32
59 : NVVM::MMATypes::tf32;
62 return NVVM::MMATypes::s8;
64 return NVVM::MMATypes::u8;
67 return NVVM::MMATypes::s32;
// No other element types are supported by this lowering.
68 llvm_unreachable(
"Unsupported type");
// Pattern lowering gpu.subgroup_mma_load_matrix to an NVVM WMMA load op.
// NOTE(review): many source lines are elided in this extract (leading
// numbers are original file line numbers); comments describe only what the
// visible statements establish.
75 struct WmmaLoadOpToNVVMLowering
81 matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp subgroupMmaLoadMatrixOp,
84 Operation *op = subgroupMmaLoadMatrixOp.getOperation();
// A transposed load maps to the column-major NVVM layout, otherwise
// row-major.
90 NVVM::MMALayout layout = subgroupMmaLoadMatrixOp.getTranspose()
91 ? NVVM::MMALayout::col
92 : NVVM::MMALayout::row;
94 cast<gpu::MMAMatrixType>(subgroupMmaLoadMatrixOp.getRes().getType());
// One WMMA dimension is inferred from the other two; which branch runs
// presumably depends on the loaded fragment kind — the selecting
// conditions are elided here, confirm against the full file.
105 n = NVVM::WMMALoadOp::inferNDimension(m, k, eltype);
109 m = NVVM::WMMALoadOp::inferMDimension(k, n, eltype);
113 k = NVVM::WMMALoadOp::inferKDimension(m, n, eltype);
115 NVVM::MMAFrag frag = convertOperand(retType.
getOperand());
// getIntrinsicID returning 0 means no NVVM intrinsic exists for this
// (m, n, k, layout, element type, fragment) combination; the failure
// action is elided here.
117 if (NVVM::WMMALoadOp::getIntrinsicID(m, n, k, layout, eltype, frag) == 0)
// Source element address computed from the memref descriptor and indices
// (callee elided — presumably getStridedElementPtr; see its declaration in
// the member docs).
126 cast<MemRefType>(subgroupMmaLoadMatrixOp.getSrcMemref().getType()),
127 adaptor.getSrcMemref(), adaptor.getIndices());
// The leading dimension is materialized as an LLVM constant taken
// verbatim from the op's leadDimension attribute.
129 Value leadingDim = rewriter.
create<LLVM::ConstantOp>(
131 subgroupMmaLoadMatrixOp.getLeadDimensionAttr());
133 op, resType, dataPtr, leadingDim, m, n, k, layout, eltype, frag);
// Pattern lowering gpu.subgroup_mma_store_matrix to an NVVM WMMA store op.
// NOTE(review): interior source lines are elided in this extract (leading
// numbers are original file line numbers).
142 struct WmmaStoreOpToNVVMLowering
148 matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp subgroupMmaStoreMatrixOp,
151 Operation *op = subgroupMmaStoreMatrixOp.getOperation();
161 cast<gpu::MMAMatrixType>(subgroupMmaStoreMatrixOp.getSrc().getType());
// Transposed store -> column-major layout, otherwise row-major.
163 NVVM::MMALayout layout = subgroupMmaStoreMatrixOp.getTranspose()
164 ? NVVM::MMALayout::col
165 : NVVM::MMALayout::row;
// m and n come straight from the stored matrix shape; k is inferred.
167 int64_t m = srcTypeShape[0];
168 int64_t n = srcTypeShape[1];
169 int64_t k = NVVM::WMMAStoreOp::inferKDimension(m, n, eltype);
// Bail out when no NVVM intrinsic matches this configuration (ID == 0);
// the failure action is elided here.
170 if (NVVM::WMMAStoreOp::getIntrinsicID(m, n, k, layout, eltype) == 0)
// The converted source is an LLVM struct; unpack each member into the
// store op's operand list.
173 auto matrixType = cast<LLVM::LLVMStructType>(adaptor.getSrc().getType());
174 for (
unsigned i = 0, e = matrixType.getBody().size(); i < e; ++i) {
176 rewriter.
create<LLVM::ExtractValueOp>(loc, adaptor.getSrc(), i);
177 storeOpOperands.push_back(toUse);
// Destination element address from the memref descriptor and indices
// (callee elided — presumably getStridedElementPtr).
182 cast<MemRefType>(subgroupMmaStoreMatrixOp.getDstMemref().getType()),
183 adaptor.getDstMemref(), adaptor.getIndices());
// Leading dimension constant taken from the op's attribute.
184 Value leadingDim = rewriter.
create<LLVM::ConstantOp>(
186 subgroupMmaStoreMatrixOp.getLeadDimensionAttr());
188 op, dataPtr, m, n, k, layout, eltype, storeOpOperands, leadingDim);
// Pattern lowering gpu.subgroup_mma_compute (C += A*B) to an NVVM WMMA mma
// op. NOTE(review): interior source lines are elided in this extract
// (leading numbers are original file line numbers).
195 struct WmmaMmaOpToNVVMLowering
201 matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
204 Operation *op = subgroupMmaComputeOp.getOperation();
// Helper: flatten one converted LLVM-struct operand into unpackedOps,
// one ExtractValueOp per struct member.
216 auto unpackOp = [&](
Value operand) {
217 auto structType = cast<LLVM::LLVMStructType>(operand.getType());
218 for (
size_t i = 0, e = structType.getBody().size(); i < e; ++i) {
219 Value toUse = rewriter.
create<LLVM::ExtractValueOp>(loc, operand, i);
220 unpackedOps.push_back(toUse);
227 cast<gpu::MMAMatrixType>(subgroupMmaComputeOp.getOpA().getType());
230 cast<gpu::MMAMatrixType>(subgroupMmaComputeOp.getOpC().getType());
// WMMA dims: m and n from the accumulator C's shape, k from A's columns.
232 int64_t m = cTypeShape[0];
233 int64_t n = cTypeShape[1];
234 int64_t k = aTypeShape[1];
// Per-operand transpose flags select column- vs row-major layouts.
235 NVVM::MMALayout aLayout = subgroupMmaComputeOp.getATranspose()
236 ? NVVM::MMALayout::col
237 : NVVM::MMALayout::row;
238 NVVM::MMALayout bLayout = subgroupMmaComputeOp.getBTranspose()
239 ? NVVM::MMALayout::col
240 : NVVM::MMALayout::row;
// Reject configurations with no matching NVVM intrinsic (ID == 0);
// condition continuation and failure action are elided here.
243 if (NVVM::WMMAMmaOp::getIntrinsicID(m, n, k, aLayout, bLayout, sourceType,
248 cast<gpu::MMAMatrixType>(subgroupMmaComputeOp.getOpB().getType()));
// A and B must carry the same element type for a single WMMA intrinsic.
249 if (bElementType != sourceType)
251 op,
"WMMA compute op input matrix element types must match.");
// Operand order matters: the intrinsic expects A, then B, then C.
253 unpackOp(adaptor.getOpA());
254 unpackOp(adaptor.getOpB());
255 unpackOp(adaptor.getOpC());
// Result reuses C's converted (LLVM struct) type.
258 op, adaptor.getOpC().getType(), m, n, k, aLayout, bLayout, sourceType,
259 destType, unpackedOps);
// Pattern lowering gpu.subgroup_mma_constant_matrix: broadcast one scalar
// into every element of the converted LLVM-struct matrix value.
// NOTE(review): interior source lines are elided in this extract (leading
// numbers are original file line numbers).
265 struct WmmaConstantOpToNVVMLowering
271 matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp subgroupMmaConstantOp,
275 adaptor.getOperands(), rewriter)))
277 Location loc = subgroupMmaConstantOp.getLoc();
// The single converted operand is the scalar to splat.
278 Value cst = adaptor.getOperands()[0];
280 cast<gpu::MMAMatrixType>(subgroupMmaConstantOp.getType()));
// If struct members are vectors, first build one splatted vector (starting
// from poison, inserting the scalar element by element) and use it as the
// per-member fill value instead of the bare scalar.
282 if (
auto vecType = dyn_cast<VectorType>(type.getBody()[0])) {
283 Value vecCst = rewriter.
create<LLVM::PoisonOp>(loc, vecType);
284 for (int64_t vecEl = 0; vecEl < vecType.getNumElements(); vecEl++) {
287 vecCst = rewriter.
create<LLVM::InsertElementOp>(loc, vecType, vecCst,
// Fill every member of the result struct with the (possibly vector) value.
292 Value matrixStruct = rewriter.
create<LLVM::PoisonOp>(loc, type);
293 for (
size_t i : llvm::seq(
size_t(0), type.getBody().size())) {
295 rewriter.
create<LLVM::InsertValueOp>(loc, matrixStruct, cst, i);
297 rewriter.
replaceOp(subgroupMmaConstantOp, matrixStruct);
// Fragment of createMinMaxF: builds a NaN-propagating float min/max from
// LLVM compare+select ops. NOTE(review): the signature's leading
// parameters and several interior lines are elided in this extract.
303 Value rhs,
bool isMin) {
// Vector inputs: presumably widens the i1 compare type to a matching
// vector — the branch body is elided, confirm against the full file.
306 if (
auto vecType = dyn_cast<VectorType>(lhs.
getType()))
// Ordered compare picks the candidate: olt for min, ogt for max.
309 loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
311 Value sel = builder.
create<LLVM::SelectOp>(loc, cmp, lhs, rhs);
// `uno` is true iff either input is NaN; in that case the result is
// replaced by a quiet NaN in the operand's float semantics.
313 loc, i1Type, LLVM::FCmpPredicate::uno, lhs, rhs);
317 APFloat::getQNaN(floatType.getFloatSemantics())));
318 return builder.
create<LLVM::SelectOp>(loc, isNan, nan, sel);
// Fragment of createScalarOp: dispatches a gpu::MMAElementwiseOp kind to
// the equivalent LLVM float arithmetic op (result typed like operand 0).
// NOTE(review): the surrounding signature/switch header lines are elided
// in this extract.
322 gpu::MMAElementwiseOp op,
325 case gpu::MMAElementwiseOp::ADDF:
326 return builder.
create<LLVM::FAddOp>(loc, operands[0].getType(), operands);
327 case gpu::MMAElementwiseOp::MULF:
328 return builder.
create<LLVM::FMulOp>(loc, operands[0].getType(), operands);
329 case gpu::MMAElementwiseOp::DIVF:
330 return builder.
create<LLVM::FDivOp>(loc, operands[0].getType(), operands);
// MAXF/MINF go through the NaN-propagating helper; the trailing isMin
// argument is elided here (presumably false for MAXF, true for MINF).
331 case gpu::MMAElementwiseOp::MAXF:
332 return createMinMaxF(builder, loc, operands[0], operands[1],
334 case gpu::MMAElementwiseOp::MINF:
335 return createMinMaxF(builder, loc, operands[0], operands[1],
338 llvm_unreachable(
"unknown op");
// Pattern lowering gpu.subgroup_mma_elementwise: apply the scalar op to
// the i-th member of every operand struct and collect results into a new
// struct. NOTE(review): interior source lines are elided in this extract
// (leading numbers are original file line numbers).
343 struct WmmaElementwiseOpToNVVMLowering
349 matchAndRewrite(gpu::SubgroupMmaElementwiseOp subgroupMmaElementwiseOp,
353 adaptor.getOperands(), rewriter)))
355 Location loc = subgroupMmaElementwiseOp.getLoc();
356 size_t numOperands = adaptor.getOperands().size();
358 cast<gpu::MMAMatrixType>(subgroupMmaElementwiseOp.getType()));
// Result struct starts as poison and is filled member by member.
359 Value matrixStruct = rewriter.
create<LLVM::PoisonOp>(loc, destType);
360 for (
size_t i = 0, e = destType.getBody().size(); i < e; ++i) {
// Gather the i-th member of every operand struct.
362 for (
size_t opIdx = 0; opIdx < numOperands; opIdx++) {
363 extractedOperands.push_back(rewriter.
create<LLVM::ExtractValueOp>(
364 loc, adaptor.getOperands()[opIdx], i));
// Apply the requested elementwise op to the extracted members.
367 createScalarOp(rewriter, loc, subgroupMmaElementwiseOp.getOpType(),
370 rewriter.
create<LLVM::InsertValueOp>(loc, matrixStruct, element, i);
372 rewriter.
replaceOp(subgroupMmaElementwiseOp, matrixStruct);
// Fragment of convertMMAToLLVMType: builds the LLVM struct type that a
// gpu::MMAMatrixType lowers to. NOTE(review): signature and several lines
// are elided in this extract; typeInfo's initializer is elided —
// presumably inferMMAType (see the member docs), confirm in the full file.
381 NVVM::MMAFrag frag = convertOperand(type.
getOperand());
385 std::pair<Type, unsigned> typeInfo =
387 return LLVM::LLVMStructType::getLiteral(
// Body fragment of populateGpuWMMAToNVVMConversionPatterns: registers all
// five WMMA lowering patterns with the given type converter and pattern
// benefit.
394 patterns.add<WmmaLoadOpToNVVMLowering, WmmaMmaOpToNVVMLowering,
395 WmmaStoreOpToNVVMLowering, WmmaConstantOpToNVVMLowering,
396 WmmaElementwiseOpToNVVMLowering>(converter, benefit);
static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
FloatAttr getFloatAttr(Type type, double value)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
ArrayRef< int64_t > getShape() const
Get shape of the matrix.
Type getElementType() const
Get elementType of a single element.
StringRef getOperand() const
The general form of operation this type supports is given by the equation C += A*B.
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDes...
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
std::pair< mlir::Type, unsigned > inferMMAType(mlir::NVVM::MMATypes type, mlir::NVVM::MMAFrag frag, int nRow, int nCol, mlir::MLIRContext *context)
Return the element type and number of elements associated with a WMMA matrix of given characteristics.
Include the generated interface declarations.
LLVM::LLVMStructType convertMMAToLLVMType(gpu::MMAMatrixType type)
Return the LLVMStructureType corresponding to the MMAMatrixType type.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateGpuWMMAToNVVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of patterns to convert WMMA ops from GPU dialect to NVVM.