31 if (!llvm::all_of(operands, [](
Value value) {
35 op,
"cannot convert if operands aren't of LLVM type.");
/// Diagnostic text for a WMMA variant that is not supported. The places that
/// emit this message are not visible in this excerpt.
42 static constexpr StringRef kInvalidCaseStr =
"Unsupported WMMA variant.";
/// Translates an operand name into the matching NVVM::MMAFrag value:
///   "AOp" -> NVVM::MMAFrag::a
///   "BOp" -> NVVM::MMAFrag::b
///   "COp" -> NVVM::MMAFrag::c
/// The callers visible in this file pass the result of `getOperand()` on a
/// matrix type. Any other string is treated as a programming error and ends
/// in llvm_unreachable rather than a recoverable failure.
44 static NVVM::MMAFrag convertOperand(StringRef operandName) {
45 if (operandName ==
"AOp")
46 return NVVM::MMAFrag::a;
47 if (operandName ==
"BOp")
48 return NVVM::MMAFrag::b;
49 if (operandName ==
"COp")
50 return NVVM::MMAFrag::c;
// No recognised operand name matched; callers must only pass the names above.
51 llvm_unreachable(
"Unknown operand name");
56 return NVVM::MMATypes::f16;
58 return type.
getOperand() ==
"COp" ? NVVM::MMATypes::f32
59 : NVVM::MMATypes::tf32;
62 return NVVM::MMATypes::s8;
64 return NVVM::MMATypes::u8;
67 return NVVM::MMATypes::s32;
68 llvm_unreachable(
"Unsupported type");
75 struct WmmaLoadOpToNVVMLowering
81 matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp subgroupMmaLoadMatrixOp,
84 Operation *op = subgroupMmaLoadMatrixOp.getOperation();
90 NVVM::MMALayout layout = subgroupMmaLoadMatrixOp.getTranspose()
91 ? NVVM::MMALayout::col
92 : NVVM::MMALayout::row;
94 cast<gpu::MMAMatrixType>(subgroupMmaLoadMatrixOp.getRes().getType());
105 n = NVVM::WMMALoadOp::inferNDimension(m, k, eltype);
109 m = NVVM::WMMALoadOp::inferMDimension(k, n, eltype);
113 k = NVVM::WMMALoadOp::inferKDimension(m, n, eltype);
115 NVVM::MMAFrag frag = convertOperand(retType.
getOperand());
117 if (NVVM::WMMALoadOp::getIntrinsicID(m, n, k, layout, eltype, frag) == 0)
124 Value dataPtr = getStridedElementPtr(
125 loc, cast<MemRefType>(subgroupMmaLoadMatrixOp.getSrcMemref().getType()),
126 adaptor.getSrcMemref(), adaptor.getIndices(), rewriter);
128 Value leadingDim = rewriter.
create<LLVM::ConstantOp>(
130 subgroupMmaLoadMatrixOp.getLeadDimensionAttr());
132 op, resType, dataPtr, leadingDim, m, n, k, layout, eltype, frag);
141 struct WmmaStoreOpToNVVMLowering
147 matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp subgroupMmaStoreMatrixOp,
150 Operation *op = subgroupMmaStoreMatrixOp.getOperation();
160 cast<gpu::MMAMatrixType>(subgroupMmaStoreMatrixOp.getSrc().getType());
162 NVVM::MMALayout layout = subgroupMmaStoreMatrixOp.getTranspose()
163 ? NVVM::MMALayout::col
164 : NVVM::MMALayout::row;
166 int64_t m = srcTypeShape[0];
167 int64_t n = srcTypeShape[1];
168 int64_t k = NVVM::WMMAStoreOp::inferKDimension(m, n, eltype);
169 if (NVVM::WMMAStoreOp::getIntrinsicID(m, n, k, layout, eltype) == 0)
172 auto matrixType = cast<LLVM::LLVMStructType>(adaptor.getSrc().getType());
173 for (
unsigned i = 0, e = matrixType.getBody().size(); i < e; ++i) {
175 rewriter.
create<LLVM::ExtractValueOp>(loc, adaptor.getSrc(), i);
176 storeOpOperands.push_back(toUse);
179 Value dataPtr = getStridedElementPtr(
181 cast<MemRefType>(subgroupMmaStoreMatrixOp.getDstMemref().getType()),
182 adaptor.getDstMemref(), adaptor.getIndices(), rewriter);
183 Value leadingDim = rewriter.
create<LLVM::ConstantOp>(
185 subgroupMmaStoreMatrixOp.getLeadDimensionAttr());
187 op, dataPtr, m, n, k, layout, eltype, storeOpOperands, leadingDim);
194 struct WmmaMmaOpToNVVMLowering
200 matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
203 Operation *op = subgroupMmaComputeOp.getOperation();
215 auto unpackOp = [&](
Value operand) {
216 auto structType = cast<LLVM::LLVMStructType>(operand.getType());
217 for (
size_t i = 0, e = structType.getBody().size(); i < e; ++i) {
218 Value toUse = rewriter.
create<LLVM::ExtractValueOp>(loc, operand, i);
219 unpackedOps.push_back(toUse);
226 cast<gpu::MMAMatrixType>(subgroupMmaComputeOp.getOpA().getType());
229 cast<gpu::MMAMatrixType>(subgroupMmaComputeOp.getOpC().getType());
231 int64_t m = cTypeShape[0];
232 int64_t n = cTypeShape[1];
233 int64_t k = aTypeShape[1];
234 NVVM::MMALayout aLayout = subgroupMmaComputeOp.getATranspose()
235 ? NVVM::MMALayout::col
236 : NVVM::MMALayout::row;
237 NVVM::MMALayout bLayout = subgroupMmaComputeOp.getBTranspose()
238 ? NVVM::MMALayout::col
239 : NVVM::MMALayout::row;
242 if (NVVM::WMMAMmaOp::getIntrinsicID(m, n, k, aLayout, bLayout, sourceType,
247 cast<gpu::MMAMatrixType>(subgroupMmaComputeOp.getOpB().getType()));
248 if (bElementType != sourceType)
250 op,
"WMMA compute op input matrix element types must match.");
252 unpackOp(adaptor.getOpA());
253 unpackOp(adaptor.getOpB());
254 unpackOp(adaptor.getOpC());
257 op, adaptor.getOpC().getType(), m, n, k, aLayout, bLayout, sourceType,
258 destType, unpackedOps);
264 struct WmmaConstantOpToNVVMLowering
270 matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp subgroupMmaConstantOp,
274 adaptor.getOperands(), rewriter)))
276 Location loc = subgroupMmaConstantOp.getLoc();
277 Value cst = adaptor.getOperands()[0];
279 cast<gpu::MMAMatrixType>(subgroupMmaConstantOp.getType()));
281 if (
auto vecType = dyn_cast<VectorType>(type.
getBody()[0])) {
282 Value vecCst = rewriter.
create<LLVM::UndefOp>(loc, vecType);
283 for (int64_t vecEl = 0; vecEl < vecType.getNumElements(); vecEl++) {
286 vecCst = rewriter.
create<LLVM::InsertElementOp>(loc, vecType, vecCst,
291 Value matrixStruct = rewriter.
create<LLVM::UndefOp>(loc, type);
292 for (
size_t i : llvm::seq(
size_t(0), type.
getBody().size())) {
294 rewriter.
create<LLVM::InsertValueOp>(loc, matrixStruct, cst, i);
296 rewriter.
replaceOp(subgroupMmaConstantOp, matrixStruct);
302 Value rhs,
bool isMin) {
305 if (
auto vecType = dyn_cast<VectorType>(lhs.
getType()))
308 loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
310 Value sel = builder.
create<LLVM::SelectOp>(loc, cmp, lhs, rhs);
312 loc, i1Type, LLVM::FCmpPredicate::uno, lhs, rhs);
316 APFloat::getQNaN(floatType.getFloatSemantics())));
317 return builder.
create<LLVM::SelectOp>(loc, isNan, nan, sel);
321 gpu::MMAElementwiseOp op,
324 case gpu::MMAElementwiseOp::ADDF:
325 return builder.
create<LLVM::FAddOp>(loc, operands[0].getType(), operands);
326 case gpu::MMAElementwiseOp::MULF:
327 return builder.
create<LLVM::FMulOp>(loc, operands[0].getType(), operands);
328 case gpu::MMAElementwiseOp::DIVF:
329 return builder.
create<LLVM::FDivOp>(loc, operands[0].getType(), operands);
330 case gpu::MMAElementwiseOp::MAXF:
331 return createMinMaxF(builder, loc, operands[0], operands[1],
333 case gpu::MMAElementwiseOp::MINF:
334 return createMinMaxF(builder, loc, operands[0], operands[1],
337 llvm_unreachable(
"unknown op");
342 struct WmmaElementwiseOpToNVVMLowering
348 matchAndRewrite(gpu::SubgroupMmaElementwiseOp subgroupMmaElementwiseOp,
352 adaptor.getOperands(), rewriter)))
354 Location loc = subgroupMmaElementwiseOp.getLoc();
355 size_t numOperands = adaptor.getOperands().size();
357 cast<gpu::MMAMatrixType>(subgroupMmaElementwiseOp.getType()));
358 Value matrixStruct = rewriter.
create<LLVM::UndefOp>(loc, destType);
359 for (
size_t i = 0, e = destType.
getBody().size(); i < e; ++i) {
361 for (
size_t opIdx = 0; opIdx < numOperands; opIdx++) {
362 extractedOperands.push_back(rewriter.
create<LLVM::ExtractValueOp>(
363 loc, adaptor.getOperands()[opIdx], i));
366 createScalarOp(rewriter, loc, subgroupMmaElementwiseOp.getOpType(),
369 rewriter.
create<LLVM::InsertValueOp>(loc, matrixStruct, element, i);
371 rewriter.
replaceOp(subgroupMmaElementwiseOp, matrixStruct);
380 NVVM::MMAFrag frag = convertOperand(type.
getOperand());
384 std::pair<Type, unsigned> typeInfo =
392 patterns.
add<WmmaLoadOpToNVVMLowering, WmmaMmaOpToNVVMLowering,
393 WmmaStoreOpToNVVMLowering, WmmaConstantOpToNVVMLowering,
394 WmmaElementwiseOpToNVVMLowering>(converter);
static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
FloatAttr getFloatAttr(Type type, double value)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Conversion from types to the LLVM IR dialect.
LLVM dialect structure type representing a collection of different-typed elements manipulated togethe...
ArrayRef< Type > getBody() const
Returns the list of element types contained in a non-opaque struct.
static LLVMStructType getLiteral(MLIRContext *context, ArrayRef< Type > types, bool isPacked=false)
Gets or creates a literal struct with the given body in the provided context.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
ArrayRef< int64_t > getShape() const
Get shape of the matrix.
Type getElementType() const
Get elementType of a single element.
StringRef getOperand() const
The general form of operation this type supports is given by the equation C += A*B.
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
std::pair< mlir::Type, unsigned > inferMMAType(mlir::NVVM::MMATypes type, mlir::NVVM::MMAFrag frag, int nRow, int nCol, mlir::MLIRContext *context)
Return the element type and number of elements associated with a wmma matrix of given characteristics.
Include the generated interface declarations.
LLVM::LLVMStructType convertMMAToLLVMType(gpu::MMAMatrixType type)
Return the LLVMStructureType corresponding to the MMAMatrixType type.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateGpuWMMAToNVVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to convert WMMA ops from GPU dialect to NVVM.