/// Checks that all operands of the op being lowered have already been
/// converted to LLVM-compatible types.
static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
                                     ConversionPatternRewriter &rewriter) {
  if (!llvm::all_of(operands, [](Value value) {
        return LLVM::isCompatibleType(value.getType());
      }))
    return rewriter.notifyMatchFailure(
        op, "cannot convert if operands aren't of LLVM type.");
  return success();
}

static constexpr StringRef kInvalidCaseStr = "Unsupported WMMA variant.";

/// Maps the operand name of an MMAMatrixType ("AOp"/"BOp"/"COp") to the
/// corresponding NVVM fragment kind.
static NVVM::MMAFrag convertOperand(StringRef operandName) {
  if (operandName == "AOp")
    return NVVM::MMAFrag::a;
  if (operandName == "BOp")
    return NVVM::MMAFrag::b;
  if (operandName == "COp")
    return NVVM::MMAFrag::c;
  llvm_unreachable("Unknown operand name");
}

/// Returns the NVVM MMA element type corresponding to the element type of
/// the given MMA matrix type.
static NVVM::MMATypes getElementType(gpu::MMAMatrixType type) {
  if (type.getElementType().isF16())
    return NVVM::MMATypes::f16;
  if (type.getElementType().isF32())
    return type.getOperand() == "COp" ? NVVM::MMATypes::f32
                                      : NVVM::MMATypes::tf32;
  if (type.getElementType().isSignedInteger(8))
    return NVVM::MMATypes::s8;
  if (type.getElementType().isUnsignedInteger(8))
    return NVVM::MMATypes::u8;
  // A signless 32-bit integer accumulator is treated as signed.
  if (type.getElementType().isInteger(32))
    return NVVM::MMATypes::s32;
  llvm_unreachable("Unsupported type");
}

/// Lowers gpu.subgroup_mma_load_matrix to the NVVM wmma.load op.
struct WmmaLoadOpToNVVMLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaLoadMatrixOp> {
  using ConvertOpToLLVMPattern<
      gpu::SubgroupMmaLoadMatrixOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp subgroupMmaLoadMatrixOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Operation *op = subgroupMmaLoadMatrixOp.getOperation();
    if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
      return failure();

    NVVM::MMALayout layout = subgroupMmaLoadMatrixOp.getTranspose()
                                 ? NVVM::MMALayout::col
                                 : NVVM::MMALayout::row;
    gpu::MMAMatrixType retType =
        cast<gpu::MMAMatrixType>(subgroupMmaLoadMatrixOp.getRes().getType());
    ArrayRef<int64_t> retTypeShape = retType.getShape();
    int64_t m = 0, n = 0, k = 0;
    NVVM::MMATypes eltype = getElementType(retType);
    // The NVVM intrinsics need all of m, n and k; infer the missing
    // dimension from the fragment kind being loaded.
    if (retType.getOperand() == "AOp") {
      m = retTypeShape[0];
      k = retTypeShape[1];
      n = NVVM::WMMALoadOp::inferNDimension(m, k, eltype);
    } else if (retType.getOperand() == "BOp") {
      k = retTypeShape[0];
      n = retTypeShape[1];
      m = NVVM::WMMALoadOp::inferMDimension(k, n, eltype);
    } else if (retType.getOperand() == "COp") {
      m = retTypeShape[0];
      n = retTypeShape[1];
      k = NVVM::WMMALoadOp::inferKDimension(m, n, eltype);
    }
    NVVM::MMAFrag frag = convertOperand(retType.getOperand());
    // Bail out if there is no intrinsic for this combination.
    if (NVVM::WMMALoadOp::getIntrinsicID(m, n, k, layout, eltype, frag) == 0)
      return rewriter.notifyMatchFailure(op, kInvalidCaseStr);

    Type resType = convertMMAToLLVMType(retType);
    Location loc = op->getLoc();

    // Compute the address of the first element to load.
    Value dataPtr = getStridedElementPtr(
        rewriter, loc, *getTypeConverter(),
        cast<MemRefType>(subgroupMmaLoadMatrixOp.getSrcMemref().getType()),
        adaptor.getSrcMemref(), adaptor.getIndices());
    Value leadingDim = LLVM::ConstantOp::create(
        rewriter, loc, rewriter.getI32Type(),
        subgroupMmaLoadMatrixOp.getLeadDimensionAttr());
    rewriter.replaceOpWithNewOp<NVVM::WMMALoadOp>(
        op, resType, dataPtr, leadingDim, m, n, k, layout, eltype, frag);
    return success();
  }
};

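// Illustrative example (assumed MLIR syntax, not taken from this file):
//   %frag = gpu.subgroup_mma_load_matrix %src[%i, %j]
//             {leadDimension = 32 : index}
//             : memref<32x32xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
// would be rewritten into an nvvm.wmma.load whose result is an LLVM struct
// of vector<2xf16> registers.
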
/// Lowers gpu.subgroup_mma_store_matrix to the NVVM wmma.store op.
struct WmmaStoreOpToNVVMLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaStoreMatrixOp> {
  using ConvertOpToLLVMPattern<
      gpu::SubgroupMmaStoreMatrixOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp subgroupMmaStoreMatrixOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Operation *op = subgroupMmaStoreMatrixOp.getOperation();
    if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
      return failure();

    Location loc = op->getLoc();

    SmallVector<Value, 4> storeOpOperands;
    // The shape of the MMAMatrix type being stored selects the intrinsic
    // this op is lowered to.
    gpu::MMAMatrixType srcType =
        cast<gpu::MMAMatrixType>(subgroupMmaStoreMatrixOp.getSrc().getType());
    ArrayRef<int64_t> srcTypeShape = srcType.getShape();
    NVVM::MMALayout layout = subgroupMmaStoreMatrixOp.getTranspose()
                                 ? NVVM::MMALayout::col
                                 : NVVM::MMALayout::row;
    NVVM::MMATypes eltype = getElementType(srcType);
    int64_t m = srcTypeShape[0];
    int64_t n = srcTypeShape[1];
    int64_t k = NVVM::WMMAStoreOp::inferKDimension(m, n, eltype);
    if (NVVM::WMMAStoreOp::getIntrinsicID(m, n, k, layout, eltype) == 0)
      return rewriter.notifyMatchFailure(op, kInvalidCaseStr);

    // Unpack the struct holding the fragment into individual store operands.
    auto matrixType = cast<LLVM::LLVMStructType>(adaptor.getSrc().getType());
    for (unsigned i = 0, e = matrixType.getBody().size(); i < e; ++i) {
      Value toUse =
          LLVM::ExtractValueOp::create(rewriter, loc, adaptor.getSrc(), i);
      storeOpOperands.push_back(toUse);
    }

    Value dataPtr = getStridedElementPtr(
        rewriter, loc, *getTypeConverter(),
        cast<MemRefType>(subgroupMmaStoreMatrixOp.getDstMemref().getType()),
        adaptor.getDstMemref(), adaptor.getIndices());
    Value leadingDim = LLVM::ConstantOp::create(
        rewriter, loc, rewriter.getI32Type(),
        subgroupMmaStoreMatrixOp.getLeadDimensionAttr());
    rewriter.replaceOpWithNewOp<NVVM::WMMAStoreOp>(
        op, dataPtr, m, n, k, layout, eltype, storeOpOperands, leadingDim);
    return success();
  }
};

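// Illustrative example (assumed MLIR syntax, not taken from this file):
//   gpu.subgroup_mma_store_matrix %frag, %dst[%i, %j]
//     {leadDimension = 32 : index}
//     : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16>
// would be rewritten into an nvvm.wmma.store fed by the unpacked struct
// elements, the computed destination pointer and the leading dimension.
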
/// Lowers the GPU MMA compute op to the NVVM wmma.mma op.
struct WmmaMmaOpToNVVMLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaComputeOp> {
  using ConvertOpToLLVMPattern<
      gpu::SubgroupMmaComputeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Operation *op = subgroupMmaComputeOp.getOperation();
    if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
      return failure();

    Location loc = op->getLoc();

    // The NVVM wmma.mma op takes the fragment elements as individual values,
    // so unpack each converted operand struct into scalars/vectors.
    SmallVector<Value> unpackedOps;

    auto unpackOp = [&](Value operand) {
      auto structType = cast<LLVM::LLVMStructType>(operand.getType());
      for (size_t i = 0, e = structType.getBody().size(); i < e; ++i) {
        Value toUse = LLVM::ExtractValueOp::create(rewriter, loc, operand, i);
        unpackedOps.push_back(toUse);
      }
    };

    // The shapes of the A and C matrices determine m, n and k, which select
    // the intrinsic this op is lowered to.
    gpu::MMAMatrixType aType =
        cast<gpu::MMAMatrixType>(subgroupMmaComputeOp.getOpA().getType());
    ArrayRef<int64_t> aTypeShape = aType.getShape();
    gpu::MMAMatrixType cType =
        cast<gpu::MMAMatrixType>(subgroupMmaComputeOp.getOpC().getType());
    ArrayRef<int64_t> cTypeShape = cType.getShape();
    int64_t m = cTypeShape[0];
    int64_t n = cTypeShape[1];
    int64_t k = aTypeShape[1];
    NVVM::MMALayout aLayout = subgroupMmaComputeOp.getATranspose()
                                  ? NVVM::MMALayout::col
                                  : NVVM::MMALayout::row;
    NVVM::MMALayout bLayout = subgroupMmaComputeOp.getBTranspose()
                                  ? NVVM::MMALayout::col
                                  : NVVM::MMALayout::row;
    NVVM::MMATypes sourceType = getElementType(aType);
    NVVM::MMATypes destType = getElementType(cType);
    if (NVVM::WMMAMmaOp::getIntrinsicID(m, n, k, aLayout, bLayout, sourceType,
                                        destType) == 0)
      return rewriter.notifyMatchFailure(op, kInvalidCaseStr);

    NVVM::MMATypes bElementType = getElementType(
        cast<gpu::MMAMatrixType>(subgroupMmaComputeOp.getOpB().getType()));
    if (bElementType != sourceType)
      return rewriter.notifyMatchFailure(
          op, "WMMA compute op input matrix element types must match.");

    unpackOp(adaptor.getOpA());
    unpackOp(adaptor.getOpB());
    unpackOp(adaptor.getOpC());

    rewriter.replaceOpWithNewOp<NVVM::WMMAMmaOp>(
        op, adaptor.getOpC().getType(), m, n, k, aLayout, bLayout, sourceType,
        destType, unpackedOps);
    return success();
  }
};

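// Illustrative example (assumed MLIR syntax, not taken from this file):
//   %d = gpu.subgroup_mma_compute %a, %b, %c
//          : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">
//          -> !gpu.mma_matrix<16x16xf16, "COp">
// would be rewritten into a single nvvm.wmma.mma fed by the unpacked A, B
// and C fragments.
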
/// Lowers the GPU MMA constant-matrix op by splatting the scalar operand
/// into every element of the struct-of-vectors fragment representation.
struct WmmaConstantOpToNVVMLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaConstantMatrixOp> {
  using ConvertOpToLLVMPattern<
      gpu::SubgroupMmaConstantMatrixOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp subgroupMmaConstantOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(areAllLLVMTypes(subgroupMmaConstantOp.getOperation(),
                               adaptor.getOperands(), rewriter)))
      return failure();
    Location loc = subgroupMmaConstantOp.getLoc();
    Value cst = adaptor.getOperands()[0];
    LLVM::LLVMStructType type = convertMMAToLLVMType(
        cast<gpu::MMAMatrixType>(subgroupMmaConstantOp.getType()));
    // If the struct elements are vectors, first splat the scalar into a vector.
    if (auto vecType = dyn_cast<VectorType>(type.getBody()[0])) {
      Value vecCst = LLVM::PoisonOp::create(rewriter, loc, vecType);
      for (int64_t vecEl = 0; vecEl < vecType.getNumElements(); vecEl++) {
        Value idx = LLVM::ConstantOp::create(rewriter, loc,
                                             rewriter.getI32Type(), vecEl);
        vecCst = LLVM::InsertElementOp::create(rewriter, loc, vecType, vecCst,
                                               cst, idx);
      }
      cst = vecCst;
    }
    Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, type);
    for (size_t i : llvm::seq(size_t(0), type.getBody().size())) {
      matrixStruct =
          LLVM::InsertValueOp::create(rewriter, loc, matrixStruct, cst, i);
    }
    rewriter.replaceOp(subgroupMmaConstantOp, matrixStruct);
    return success();
  }
};

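// Illustrative example (assumed MLIR syntax, not taken from this file):
//   %c = gpu.subgroup_mma_constant_matrix %cst : !gpu.mma_matrix<16x16xf16, "COp">
// would be rewritten into insertelement/insertvalue chains splatting %cst
// into every element of the struct-of-vectors fragment.
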
/// Emits an LLVM min/max sequence that returns NaN if either input is NaN.
static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
                           Value rhs, bool isMin) {
  auto floatType = cast<FloatType>(getElementTypeOrSelf(lhs.getType()));
  Type i1Type = builder.getI1Type();
  if (auto vecType = dyn_cast<VectorType>(lhs.getType()))
    i1Type = VectorType::get(vecType.getShape(), i1Type);
  Value cmp = LLVM::FCmpOp::create(
      builder, loc, i1Type,
      isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt, lhs, rhs);
  Value sel = LLVM::SelectOp::create(builder, loc, cmp, lhs, rhs);
  Value isNan = LLVM::FCmpOp::create(builder, loc, i1Type,
                                     LLVM::FCmpPredicate::uno, lhs, rhs);
  Value nan = LLVM::ConstantOp::create(
      builder, loc, lhs.getType(),
      builder.getFloatAttr(floatType,
                           APFloat::getQNaN(floatType.getFloatSemantics())));
  return LLVM::SelectOp::create(builder, loc, isNan, nan, sel);
}

/// Lowers a single MMA elementwise operation to the matching LLVM op.
static Value createScalarOp(OpBuilder &builder, Location loc,
                            gpu::MMAElementwiseOp op,
                            ArrayRef<Value> operands) {
  switch (op) {
  case gpu::MMAElementwiseOp::ADDF:
    return LLVM::FAddOp::create(builder, loc, operands[0].getType(), operands);
  case gpu::MMAElementwiseOp::MULF:
    return LLVM::FMulOp::create(builder, loc, operands[0].getType(), operands);
  case gpu::MMAElementwiseOp::DIVF:
    return LLVM::FDivOp::create(builder, loc, operands[0].getType(), operands);
  case gpu::MMAElementwiseOp::MAXF:
    return createMinMaxF(builder, loc, operands[0], operands[1],
                         /*isMin=*/false);
  case gpu::MMAElementwiseOp::MINF:
    return createMinMaxF(builder, loc, operands[0], operands[1],
                         /*isMin=*/true);
  default:
    llvm_unreachable("unknown op");
  }
}

/// Lowers the GPU MMA elementwise op by applying the scalar/vector operation
/// to each element of the struct-of-vectors fragment.
struct WmmaElementwiseOpToNVVMLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaElementwiseOp> {
  using ConvertOpToLLVMPattern<
      gpu::SubgroupMmaElementwiseOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaElementwiseOp subgroupMmaElementwiseOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(areAllLLVMTypes(subgroupMmaElementwiseOp.getOperation(),
                               adaptor.getOperands(), rewriter)))
      return failure();
    Location loc = subgroupMmaElementwiseOp.getLoc();
    size_t numOperands = adaptor.getOperands().size();
    LLVM::LLVMStructType destType = convertMMAToLLVMType(
        cast<gpu::MMAMatrixType>(subgroupMmaElementwiseOp.getType()));
    Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, destType);
    for (size_t i = 0, e = destType.getBody().size(); i < e; ++i) {
      SmallVector<Value> extractedOperands;
      for (size_t opIdx = 0; opIdx < numOperands; opIdx++) {
        extractedOperands.push_back(LLVM::ExtractValueOp::create(
            rewriter, loc, adaptor.getOperands()[opIdx], i));
      }
      Value element =
          createScalarOp(rewriter, loc, subgroupMmaElementwiseOp.getOpType(),
                         extractedOperands);
      matrixStruct =
          LLVM::InsertValueOp::create(rewriter, loc, matrixStruct, element, i);
    }
    rewriter.replaceOp(subgroupMmaElementwiseOp, matrixStruct);
    return success();
  }
};

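// Illustrative example (assumed MLIR syntax, not taken from this file):
//   %r = gpu.subgroup_mma_elementwise addf %a, %b
//          : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">)
//          -> !gpu.mma_matrix<16x16xf16, "COp">
// would be rewritten into per-element extractvalue / llvm.fadd / insertvalue
// chains over the struct-of-vectors fragment.
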
/// Returns the LLVM struct type corresponding to the given MMAMatrix type.
LLVM::LLVMStructType mlir::convertMMAToLLVMType(gpu::MMAMatrixType type) {
  NVVM::MMAFrag frag = convertOperand(type.getOperand());
  NVVM::MMATypes eltType = getElementType(type);
  std::pair<Type, unsigned> typeInfo = NVVM::inferMMAType(
      eltType, frag, type.getShape()[0], type.getShape()[1], type.getContext());
  return LLVM::LLVMStructType::getLiteral(
      type.getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
}

void mlir::populateGpuWMMAToNVVMConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
    PatternBenefit benefit) {
  patterns.add<WmmaLoadOpToNVVMLowering, WmmaMmaOpToNVVMLowering,
               WmmaStoreOpToNVVMLowering, WmmaConstantOpToNVVMLowering,
               WmmaElementwiseOpToNVVMLowering>(converter, benefit);
}

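// Minimal usage sketch (illustrative only; the pass boilerplate below is an
// assumption and not part of this file, only
// populateGpuWMMAToNVVMConversionPatterns is defined here):
//
//   void lowerWmmaToNVVM(ModuleOp module) {
//     MLIRContext *ctx = module.getContext();
//     LLVMTypeConverter converter(ctx);
//     RewritePatternSet patterns(ctx);
//     populateGpuWMMAToNVVMConversionPatterns(converter, patterns);
//     LLVMConversionTarget target(*ctx);
//     target.addIllegalOp<gpu::SubgroupMmaLoadMatrixOp,
//                         gpu::SubgroupMmaStoreMatrixOp,
//                         gpu::SubgroupMmaComputeOp>();
//     if (failed(applyPartialConversion(module, target, std::move(patterns))))
//       module.emitError("WMMA to NVVM lowering failed");
//   }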