// NOTE(review): tail of `areAllLLVMTypes(Operation *op, ValueRange operands,
// ConversionPatternRewriter &rewriter)` (signature confirmed by the reference
// index below); the opening of the signature and several body lines are
// elided in this extraction. Visible logic: checks all operands with
// llvm::all_of and reports a match failure when some operand is not of an
// LLVM-compatible type.
31 ConversionPatternRewriter &rewriter) {
32 if (!llvm::all_of(operands, [](
Value value) {
35 return rewriter.notifyMatchFailure(
36 op,
"cannot convert if operands aren't of LLVM type.");
/// Error message used by the match-failure paths below when no NVVM WMMA
/// intrinsic exists for the requested (shape, layout, element-type) combo.
43static constexpr StringRef kInvalidCaseStr =
"Unsupported WMMA variant.";
/// Maps a gpu::MMAMatrixType operand name ("AOp"/"BOp"/"COp") to the
/// corresponding NVVM MMA fragment kind. Any other name is a programmer
/// error and aborts via llvm_unreachable.
45static NVVM::MMAFrag convertOperand(StringRef operandName) {
46 if (operandName ==
"AOp")
47 return NVVM::MMAFrag::a;
48 if (operandName ==
"BOp")
49 return NVVM::MMAFrag::b;
50 if (operandName ==
"COp")
51 return NVVM::MMAFrag::c;
52 llvm_unreachable(
// Unreachable: gpu.mma_matrix operand names are restricted to the three
// cases above.
"Unknown operand name");
// NOTE(review): fragment of an element-type mapping helper (MMAMatrixType ->
// NVVM::MMATypes). The `if` conditions guarding each return were elided by
// the extraction — only the return statements survive. One visible detail:
// for the f32 case, the accumulator operand ("COp") maps to f32 while the
// A/B operands map to tf32.
57 return NVVM::MMATypes::f16;
59 return type.
getOperand() ==
"COp" ? NVVM::MMATypes::f32
60 : NVVM::MMATypes::tf32;
62 return NVVM::MMATypes::f64;
64 return NVVM::MMATypes::s8;
66 return NVVM::MMATypes::u8;
69 return NVVM::MMATypes::s32;
70 llvm_unreachable(
"Unsupported type");
/// Lowers gpu.subgroup_mma_load_matrix to nvvm.wmma.load.
/// NOTE(review): several lines of this struct are elided in this extraction
/// (base-class clause, parts of matchAndRewrite, closing braces).
77struct WmmaLoadOpToNVVMLowering
79 using ConvertOpToLLVMPattern<
80 gpu::SubgroupMmaLoadMatrixOp>::ConvertOpToLLVMPattern;
83 matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp subgroupMmaLoadMatrixOp,
85 ConversionPatternRewriter &rewriter)
const override {
86 Operation *op = subgroupMmaLoadMatrixOp.getOperation();
// `transpose` selects column-major layout, otherwise row-major.
92 NVVM::MMALayout layout = subgroupMmaLoadMatrixOp.getTranspose()
93 ? NVVM::MMALayout::col
94 : NVVM::MMALayout::row;
95 gpu::MMAMatrixType retType =
96 cast<gpu::MMAMatrixType>(subgroupMmaLoadMatrixOp.getRes().getType());
97 ArrayRef<int64_t> retTypeShape = retType.
getShape();
// The result type fixes only two of (m, n, k); the third is inferred from
// the WMMA load op for the given element type. Which branch runs depends on
// the operand kind (elided conditions).
107 n = NVVM::WMMALoadOp::inferNDimension(m, k, eltype);
111 m = NVVM::WMMALoadOp::inferMDimension(k, n, eltype);
115 k = NVVM::WMMALoadOp::inferKDimension(m, n, eltype);
117 NVVM::MMAFrag frag = convertOperand(retType.
getOperand());
// Bail out when no NVVM intrinsic exists for this variant.
119 if (NVVM::WMMALoadOp::getIntrinsicID(m, n, k, layout, eltype, frag) == 0)
120 return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
123 Location loc = op->
getLoc();
// Compute the source address from the memref descriptor and indices
// (elided call to getStridedElementPtr).
128 cast<MemRefType>(subgroupMmaLoadMatrixOp.getSrcMemref().getType()),
129 adaptor.getSrcMemref(), adaptor.getIndices());
// Leading dimension is passed as an i32 constant taken from the op's
// leadDimension attribute.
131 Value leadingDim = LLVM::ConstantOp::create(
132 rewriter, loc, rewriter.getI32Type(),
133 subgroupMmaLoadMatrixOp.getLeadDimensionAttr());
134 rewriter.replaceOpWithNewOp<NVVM::WMMALoadOp>(
135 op, resType, dataPtr, leadingDim, m, n, k, layout, eltype, frag);
/// Lowers gpu.subgroup_mma_store_matrix to nvvm.wmma.store.
/// NOTE(review): several lines of this struct are elided in this extraction.
144struct WmmaStoreOpToNVVMLowering
146 using ConvertOpToLLVMPattern<
147 gpu::SubgroupMmaStoreMatrixOp>::ConvertOpToLLVMPattern;
150 matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp subgroupMmaStoreMatrixOp,
152 ConversionPatternRewriter &rewriter)
const override {
153 Operation *op = subgroupMmaStoreMatrixOp.getOperation();
157 Location loc = op->
getLoc();
159 SmallVector<Value, 4> storeOpOperands;
// m and n come from the source matrix shape; k is inferred for the store.
162 gpu::MMAMatrixType srcType =
163 cast<gpu::MMAMatrixType>(subgroupMmaStoreMatrixOp.getSrc().getType());
164 ArrayRef<int64_t> srcTypeShape = srcType.
getShape();
165 NVVM::MMALayout layout = subgroupMmaStoreMatrixOp.getTranspose()
166 ? NVVM::MMALayout::col
167 : NVVM::MMALayout::row;
169 int64_t m = srcTypeShape[0];
170 int64_t n = srcTypeShape[1];
171 int64_t k = NVVM::WMMAStoreOp::inferKDimension(m, n, eltype);
// Bail out when no NVVM intrinsic exists for this variant.
172 if (NVVM::WMMAStoreOp::getIntrinsicID(m, n, k, layout, eltype) == 0)
173 return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
// The converted source matrix is an LLVM struct; unpack each field into a
// separate intrinsic operand.
175 auto matrixType = cast<LLVM::LLVMStructType>(adaptor.getSrc().getType());
176 for (
unsigned i = 0, e = matrixType.getBody().size(); i < e; ++i) {
178 LLVM::ExtractValueOp::create(rewriter, loc, adaptor.getSrc(), i);
179 storeOpOperands.push_back(toUse);
// Compute the destination address (elided call to getStridedElementPtr).
184 cast<MemRefType>(subgroupMmaStoreMatrixOp.getDstMemref().getType()),
185 adaptor.getDstMemref(), adaptor.getIndices());
186 Value leadingDim = LLVM::ConstantOp::create(
187 rewriter, loc, rewriter.getI32Type(),
188 subgroupMmaStoreMatrixOp.getLeadDimensionAttr());
189 rewriter.replaceOpWithNewOp<NVVM::WMMAStoreOp>(
190 op, dataPtr, m, n, k, layout, eltype, storeOpOperands, leadingDim);
/// Lowers gpu.subgroup_mma_compute (C += A*B) to nvvm.wmma.mma.
/// NOTE(review): several lines of this struct are elided in this extraction.
197struct WmmaMmaOpToNVVMLowering
199 using ConvertOpToLLVMPattern<
200 gpu::SubgroupMmaComputeOp>::ConvertOpToLLVMPattern;
203 matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
205 ConversionPatternRewriter &rewriter)
const override {
206 Operation *op = subgroupMmaComputeOp.getOperation();
210 Location loc = op->
getLoc();
// The intrinsic takes the struct fields of A, B and C as a flat operand
// list; unpackOp appends either the scalar operand itself or its extracted
// struct fields.
216 SmallVector<Value> unpackedOps;
217 auto unpackOp = [&](Value operand) {
219 if (!isa<LLVM::LLVMStructType>(operand.getType())) {
220 unpackedOps.push_back(operand);
224 auto structType = cast<LLVM::LLVMStructType>(operand.getType());
225 for (
size_t i = 0, e = structType.getBody().size(); i < e; ++i) {
226 Value toUse = LLVM::ExtractValueOp::create(rewriter, loc, operand, i);
227 unpackedOps.push_back(toUse);
// (m, n) come from the accumulator C's shape; k from A's second dimension.
233 gpu::MMAMatrixType aType =
234 cast<gpu::MMAMatrixType>(subgroupMmaComputeOp.getOpA().getType());
235 ArrayRef<int64_t> aTypeShape = aType.
getShape();
236 gpu::MMAMatrixType cType =
237 cast<gpu::MMAMatrixType>(subgroupMmaComputeOp.getOpC().getType());
238 ArrayRef<int64_t> cTypeShape = cType.
getShape();
239 int64_t m = cTypeShape[0];
240 int64_t n = cTypeShape[1];
241 int64_t k = aTypeShape[1];
242 NVVM::MMALayout aLayout = subgroupMmaComputeOp.getATranspose()
243 ? NVVM::MMALayout::col
244 : NVVM::MMALayout::row;
245 NVVM::MMALayout bLayout = subgroupMmaComputeOp.getBTranspose()
246 ? NVVM::MMALayout::col
247 : NVVM::MMALayout::row;
// Bail out when no NVVM intrinsic exists for this variant.
250 if (NVVM::WMMAMmaOp::getIntrinsicID(m, n, k, aLayout, bLayout, sourceType,
252 return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
// A and B must share an element type; mixed-precision inputs are rejected.
255 cast<gpu::MMAMatrixType>(subgroupMmaComputeOp.getOpB().getType()));
256 if (bElementType != sourceType)
257 return rewriter.notifyMatchFailure(
258 op,
"WMMA compute op input matrix element types must match.");
260 unpackOp(adaptor.getOpA());
261 unpackOp(adaptor.getOpB());
262 unpackOp(adaptor.getOpC());
264 rewriter.replaceOpWithNewOp<NVVM::WMMAMmaOp>(
265 op, adaptor.getOpC().
getType(), m, n, k, aLayout, bLayout, sourceType,
266 destType, unpackedOps);
/// Lowers gpu.subgroup_mma_constant_matrix by splatting the scalar operand
/// into every field (and, when fields are vectors, every lane) of the
/// converted LLVM struct type.
/// NOTE(review): several lines of this struct are elided in this extraction.
272struct WmmaConstantOpToNVVMLowering
274 using ConvertOpToLLVMPattern<
275 gpu::SubgroupMmaConstantMatrixOp>::ConvertOpToLLVMPattern;
278 matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp subgroupMmaConstantOp,
280 ConversionPatternRewriter &rewriter)
const override {
282 adaptor.getOperands(), rewriter)))
284 Location loc = subgroupMmaConstantOp.getLoc();
285 Value cst = adaptor.getOperands()[0];
287 cast<gpu::MMAMatrixType>(subgroupMmaConstantOp.getType()));
289 auto structType = dyn_cast<LLVM::LLVMStructType>(type);
// If the converted type is not a struct, the scalar constant itself is the
// replacement value.
291 rewriter.replaceOp(subgroupMmaConstantOp, cst);
// Struct fields may themselves be vectors: build a vector splat by
// repeated insertelement starting from poison, then reuse it per field.
295 if (
auto vecType = dyn_cast<VectorType>(structType.getBody()[0])) {
296 Value vecCst = LLVM::PoisonOp::create(rewriter, loc, vecType);
297 for (int64_t vecEl = 0; vecEl < vecType.getNumElements(); vecEl++) {
298 Value idx = LLVM::ConstantOp::create(rewriter, loc,
299 rewriter.getI32Type(), vecEl);
300 vecCst = LLVM::InsertElementOp::create(rewriter, loc, vecType, vecCst,
// Fill every struct field with the (possibly vectorized) constant.
305 Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, structType);
306 for (
size_t i : llvm::seq(
size_t(0), structType.getBody().size())) {
308 LLVM::InsertValueOp::create(rewriter, loc, matrixStruct, cst, i);
310 rewriter.replaceOp(subgroupMmaConstantOp, matrixStruct);
// NOTE(review): fragment of a min/max helper whose signature (builder, loc,
// lhs, rhs, isMin, ...) is elided above this chunk. Visible behavior:
// compares with ordered olt/ogt, selects the winner, then a second unordered
// (uno) compare detects NaN inputs and substitutes a quiet NaN of the
// operand's float semantics — i.e. NaN-propagating min/max.
319 if (
auto vecType = dyn_cast<VectorType>(
lhs.getType()))
// Vector operands need a matching vector-of-i1 compare result type.
320 i1Type = VectorType::get(vecType.getShape(), i1Type);
321 Value cmp = LLVM::FCmpOp::create(
322 builder, loc, i1Type,
323 isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
lhs,
rhs);
324 Value sel = LLVM::SelectOp::create(builder, loc, cmp,
lhs,
rhs);
325 Value isNan = LLVM::FCmpOp::create(builder, loc, i1Type,
326 LLVM::FCmpPredicate::uno,
lhs,
rhs);
327 Value nan = LLVM::ConstantOp::create(
328 builder, loc,
lhs.getType(),
330 APFloat::getQNaN(floatType.getFloatSemantics())));
331 return LLVM::SelectOp::create(builder, loc, isNan, nan, sel);
// NOTE(review): fragment of createScalarOp — dispatches a
// gpu::MMAElementwiseOp kind to the matching LLVM float arithmetic op;
// MAXF/MINF route through the NaN-propagating createMinMaxF helper. The
// enclosing signature and switch header are elided in this extraction.
335 gpu::MMAElementwiseOp op,
338 case gpu::MMAElementwiseOp::ADDF:
339 return LLVM::FAddOp::create(builder, loc, operands[0].
getType(), operands);
340 case gpu::MMAElementwiseOp::MULF:
341 return LLVM::FMulOp::create(builder, loc, operands[0].
getType(), operands);
342 case gpu::MMAElementwiseOp::DIVF:
343 return LLVM::FDivOp::create(builder, loc, operands[0].
getType(), operands);
344 case gpu::MMAElementwiseOp::MAXF:
345 return createMinMaxF(builder, loc, operands[0], operands[1],
347 case gpu::MMAElementwiseOp::MINF:
348 return createMinMaxF(builder, loc, operands[0], operands[1],
351 llvm_unreachable(
"unknown op");
/// Lowers gpu.subgroup_mma_elementwise by applying the scalar op either
/// directly (scalar-typed result) or field-by-field across the converted
/// LLVM struct operands.
/// NOTE(review): several lines of this struct are elided in this extraction.
356struct WmmaElementwiseOpToNVVMLowering
358 using ConvertOpToLLVMPattern<
359 gpu::SubgroupMmaElementwiseOp>::ConvertOpToLLVMPattern;
362 matchAndRewrite(gpu::SubgroupMmaElementwiseOp subgroupMmaElementwiseOp,
364 ConversionPatternRewriter &rewriter)
const override {
366 adaptor.getOperands(), rewriter)))
368 Location loc = subgroupMmaElementwiseOp.getLoc();
369 size_t numOperands = adaptor.getOperands().size();
371 cast<gpu::MMAMatrixType>(subgroupMmaElementwiseOp.getType()));
374 LLVM::LLVMStructType structDestTy =
375 dyn_cast<LLVM::LLVMStructType>(destType);
// Non-struct destination: apply the scalar op once to the raw operands.
377 SmallVector<Value> operands;
378 for (
auto operand : adaptor.getOperands()) {
379 operands.push_back(operand);
381 Value element = createScalarOp(
382 rewriter, loc, subgroupMmaElementwiseOp.getOpType(), operands);
383 rewriter.replaceOp(subgroupMmaElementwiseOp, element);
// Struct destination: for each field, extract that field from every
// operand, apply the scalar op, and insert the result into the output
// struct built up from poison.
386 Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, structDestTy);
387 for (
size_t i = 0, e = structDestTy.getBody().size(); i < e; ++i) {
388 SmallVector<Value> extractedOperands;
389 for (
size_t opIdx = 0; opIdx < numOperands; opIdx++) {
390 extractedOperands.push_back(LLVM::ExtractValueOp::create(
391 rewriter, loc, adaptor.getOperands()[opIdx], i));
394 createScalarOp(rewriter, loc, subgroupMmaElementwiseOp.getOpType(),
397 LLVM::InsertValueOp::create(rewriter, loc, matrixStruct, element, i);
399 rewriter.replaceOp(subgroupMmaElementwiseOp, matrixStruct);
// NOTE(review): fragment of convertMMAToLLVMType(gpu::MMAMatrixType) — per
// the reference index below it returns the LLVM struct type corresponding to
// an MMAMatrixType. Visible logic: derive the fragment kind from the operand
// name, query inferMMAType (elided call) for element type and count, with a
// special case when the element is a single f64.
408 NVVM::MMAFrag frag = convertOperand(type.
getOperand());
412 std::pair<Type, unsigned> typeInfo =
415 Type f64Ty = Float64Type::get(type.getContext());
416 if (typeInfo.first == f64Ty && typeInfo.second == 1) {
419 return LLVM::LLVMStructType::getLiteral(
// NOTE(review): body of populateGpuWMMAToNVVMConversionPatterns (signature
// elided; see the reference index below) — registers all five WMMA lowering
// patterns defined in this file with the given type converter and benefit.
426 patterns.add<WmmaLoadOpToNVVMLowering, WmmaMmaOpToNVVMLowering,
427 WmmaStoreOpToNVVMLowering, WmmaConstantOpToNVVMLowering,
428 WmmaElementwiseOpToNVVMLowering>(converter, benefit);
static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter)
static Type getElementType(Type type)
Determine the element type of type.
FloatAttr getFloatAttr(Type type, double value)
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
ArrayRef< int64_t > getShape() const
Get shape of the matrix.
Type getElementType() const
Get elementType of a single element.
StringRef getOperand() const
The general form of operation this type supports is given by the equation C += A*B.
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDes...
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
std::pair< mlir::Type, unsigned > inferMMAType(mlir::NVVM::MMATypes type, mlir::NVVM::MMAFrag frag, int nRow, int nCol, mlir::MLIRContext *context)
Return the element type and number of elements associated with a wmma matrix of given characteristics.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Type convertMMAToLLVMType(gpu::MMAMatrixType type)
Return the LLVMStructureType corresponding to the MMAMatrixType type.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void populateGpuWMMAToNVVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of patterns to convert WMMA ops from GPU dialect to NVVM.