MLIR 23.0.0git
WmmaOpsToNvvm.cpp
Go to the documentation of this file.
1//===------ WmmaOpsToNVVM.cpp - WMMA LD/ST/Compute to NVVM lowering -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains definitions of patterns to lower GPU Subgroup MMA ops to
10// NVVM Dialect.
11//
12//===----------------------------------------------------------------------===//
13
20#include "mlir/IR/Types.h"
21
22using namespace mlir;
23
24namespace {
25
26/// Checks if all the operands of the op being lowered are of LLVM Types. The
27/// types are expected to be converted by the `LLVMTypeConverter` before the op
28/// is actually lowered. If the type of an operands is not already converted it
29/// hints a missing typeConversion and failure is returned in that case.
30static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
31 ConversionPatternRewriter &rewriter) {
32 if (!llvm::all_of(operands, [](Value value) {
33 return LLVM::isCompatibleType(value.getType());
34 })) {
35 return rewriter.notifyMatchFailure(
36 op, "cannot convert if operands aren't of LLVM type.");
37 }
38
39 return success();
40}
41
/// Error string to emit when an unimplemented WMMA variant is encountered.
/// Used as the match-failure message whenever no NVVM intrinsic exists for a
/// requested (m, n, k, layout, type) combination.
static constexpr StringRef kInvalidCaseStr = "Unsupported WMMA variant.";
44
45static NVVM::MMAFrag convertOperand(StringRef operandName) {
46 if (operandName == "AOp")
47 return NVVM::MMAFrag::a;
48 if (operandName == "BOp")
49 return NVVM::MMAFrag::b;
50 if (operandName == "COp")
51 return NVVM::MMAFrag::c;
52 llvm_unreachable("Unknown operand name");
53}
54
55static NVVM::MMATypes getElementType(gpu::MMAMatrixType type) {
56 if (type.getElementType().isF16())
57 return NVVM::MMATypes::f16;
58 if (type.getElementType().isF32())
59 return type.getOperand() == "COp" ? NVVM::MMATypes::f32
60 : NVVM::MMATypes::tf32;
61 if (type.getElementType().isF64())
62 return NVVM::MMATypes::f64;
63 if (type.getElementType().isSignedInteger(8))
64 return NVVM::MMATypes::s8;
66 return NVVM::MMATypes::u8;
67 // Accumulator type is signless and implies signed.
68 if (type.getElementType().isInteger(32))
69 return NVVM::MMATypes::s32;
70 llvm_unreachable("Unsupported type");
71}
72
73/// This class implements the conversion of GPU MMA loadOp to wmma.load op
74/// in the NVVM dialect. The conversion not only emits the NVVM op but also
75/// emits code that is necessary to store the data in the destination memref
76/// after it has been loaded.
77struct WmmaLoadOpToNVVMLowering
78 : public ConvertOpToLLVMPattern<gpu::SubgroupMmaLoadMatrixOp> {
79 using ConvertOpToLLVMPattern<
80 gpu::SubgroupMmaLoadMatrixOp>::ConvertOpToLLVMPattern;
81
82 LogicalResult
83 matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp subgroupMmaLoadMatrixOp,
84 OpAdaptor adaptor,
85 ConversionPatternRewriter &rewriter) const override {
86 Operation *op = subgroupMmaLoadMatrixOp.getOperation();
87 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
88 return failure();
89
90 // Get the shape of the MMAMatrix type being returned. The shape will
91 // choose which intrinsic this op will be lowered to.
92 NVVM::MMALayout layout = subgroupMmaLoadMatrixOp.getTranspose()
93 ? NVVM::MMALayout::col
94 : NVVM::MMALayout::row;
95 gpu::MMAMatrixType retType =
96 cast<gpu::MMAMatrixType>(subgroupMmaLoadMatrixOp.getRes().getType());
97 ArrayRef<int64_t> retTypeShape = retType.getShape();
98 int64_t m = 0;
99 int64_t n = 0;
100 int64_t k = 0;
101 NVVM::MMATypes eltype = getElementType(retType);
102 // NVVM intrinsics require to give mxnxk dimensions, infer the missing
103 // dimension based on the valid intrinsics available.
104 if (retType.getOperand() == "AOp") {
105 m = retTypeShape[0];
106 k = retTypeShape[1];
107 n = NVVM::WMMALoadOp::inferNDimension(m, k, eltype);
108 } else if (retType.getOperand() == "BOp") {
109 k = retTypeShape[0];
110 n = retTypeShape[1];
111 m = NVVM::WMMALoadOp::inferMDimension(k, n, eltype);
112 } else if (retType.getOperand() == "COp") {
113 m = retTypeShape[0];
114 n = retTypeShape[1];
115 k = NVVM::WMMALoadOp::inferKDimension(m, n, eltype);
116 }
117 NVVM::MMAFrag frag = convertOperand(retType.getOperand());
118 // Check that there is an exisiting instruction for the combination we need.
119 if (NVVM::WMMALoadOp::getIntrinsicID(m, n, k, layout, eltype, frag) == 0)
120 return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
121
122 Type resType = convertMMAToLLVMType(retType);
123 Location loc = op->getLoc();
124
125 // Create nvvm.mma_load op according to the operand types.
126 Value dataPtr = getStridedElementPtr(
127 rewriter, loc,
128 cast<MemRefType>(subgroupMmaLoadMatrixOp.getSrcMemref().getType()),
129 adaptor.getSrcMemref(), adaptor.getIndices());
130
131 Value leadingDim = LLVM::ConstantOp::create(
132 rewriter, loc, rewriter.getI32Type(),
133 subgroupMmaLoadMatrixOp.getLeadDimensionAttr());
134 rewriter.replaceOpWithNewOp<NVVM::WMMALoadOp>(
135 op, resType, dataPtr, leadingDim, m, n, k, layout, eltype, frag);
136 return success();
137 }
138};
139
140/// This class implements the conversion of GPU MMA storeOp to wmma.store op
141/// in the NVVM dialect. The conversion not only emits the NVVM op but also
142/// emits code that is necessary to unpack the data in the source and
143/// convert the data in the format that is needed by the NVVM op.
144struct WmmaStoreOpToNVVMLowering
145 : public ConvertOpToLLVMPattern<gpu::SubgroupMmaStoreMatrixOp> {
146 using ConvertOpToLLVMPattern<
147 gpu::SubgroupMmaStoreMatrixOp>::ConvertOpToLLVMPattern;
148
149 LogicalResult
150 matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp subgroupMmaStoreMatrixOp,
151 OpAdaptor adaptor,
152 ConversionPatternRewriter &rewriter) const override {
153 Operation *op = subgroupMmaStoreMatrixOp.getOperation();
154 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
155 return failure();
156
157 Location loc = op->getLoc();
158
159 SmallVector<Value, 4> storeOpOperands;
160 // Get the shape of the MMAMatrix type being stored. The shape will
161 // choose which intrinsic this op will be lowered to.
162 gpu::MMAMatrixType srcType =
163 cast<gpu::MMAMatrixType>(subgroupMmaStoreMatrixOp.getSrc().getType());
164 ArrayRef<int64_t> srcTypeShape = srcType.getShape();
165 NVVM::MMALayout layout = subgroupMmaStoreMatrixOp.getTranspose()
166 ? NVVM::MMALayout::col
167 : NVVM::MMALayout::row;
168 NVVM::MMATypes eltype = getElementType(srcType);
169 int64_t m = srcTypeShape[0];
170 int64_t n = srcTypeShape[1];
171 int64_t k = NVVM::WMMAStoreOp::inferKDimension(m, n, eltype);
172 if (NVVM::WMMAStoreOp::getIntrinsicID(m, n, k, layout, eltype) == 0)
173 return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
174
175 auto matrixType = cast<LLVM::LLVMStructType>(adaptor.getSrc().getType());
176 for (unsigned i = 0, e = matrixType.getBody().size(); i < e; ++i) {
177 Value toUse =
178 LLVM::ExtractValueOp::create(rewriter, loc, adaptor.getSrc(), i);
179 storeOpOperands.push_back(toUse);
180 }
181
182 Value dataPtr = getStridedElementPtr(
183 rewriter, loc,
184 cast<MemRefType>(subgroupMmaStoreMatrixOp.getDstMemref().getType()),
185 adaptor.getDstMemref(), adaptor.getIndices());
186 Value leadingDim = LLVM::ConstantOp::create(
187 rewriter, loc, rewriter.getI32Type(),
188 subgroupMmaStoreMatrixOp.getLeadDimensionAttr());
189 rewriter.replaceOpWithNewOp<NVVM::WMMAStoreOp>(
190 op, dataPtr, m, n, k, layout, eltype, storeOpOperands, leadingDim);
191 return success();
192 }
193};
194
195/// This class implements the conversion of GPU MMA computeOp to wmma.mma op
196/// in the NVVM dialect.
197struct WmmaMmaOpToNVVMLowering
198 : public ConvertOpToLLVMPattern<gpu::SubgroupMmaComputeOp> {
199 using ConvertOpToLLVMPattern<
200 gpu::SubgroupMmaComputeOp>::ConvertOpToLLVMPattern;
201
202 LogicalResult
203 matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
204 OpAdaptor adaptor,
205 ConversionPatternRewriter &rewriter) const override {
206 Operation *op = subgroupMmaComputeOp.getOperation();
207 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
208 return failure();
209
210 Location loc = op->getLoc();
211
212 // The wmma.mma intrinsic in llvm requires the operands as individual
213 // values. So individual elements from the memrefs need to be extracted and
214 // then passed on to the intrinsic call. Emit llvm ops to extract individual
215 // values form lowered memrefs.
216 SmallVector<Value> unpackedOps;
217 auto unpackOp = [&](Value operand) {
218 // f64 a and b fragments are not structs but scalars.
219 if (!isa<LLVM::LLVMStructType>(operand.getType())) {
220 unpackedOps.push_back(operand);
221 return;
222 }
223 // every other type is lowered to an LLVM struct, extract the values.
224 auto structType = cast<LLVM::LLVMStructType>(operand.getType());
225 for (size_t i = 0, e = structType.getBody().size(); i < e; ++i) {
226 Value toUse = LLVM::ExtractValueOp::create(rewriter, loc, operand, i);
227 unpackedOps.push_back(toUse);
228 }
229 };
230
231 // Get the shapes of the MMAMatrix type being used. The shapes will
232 // choose which intrinsic this op will be lowered to.
233 gpu::MMAMatrixType aType =
234 cast<gpu::MMAMatrixType>(subgroupMmaComputeOp.getOpA().getType());
235 ArrayRef<int64_t> aTypeShape = aType.getShape();
236 gpu::MMAMatrixType cType =
237 cast<gpu::MMAMatrixType>(subgroupMmaComputeOp.getOpC().getType());
238 ArrayRef<int64_t> cTypeShape = cType.getShape();
239 int64_t m = cTypeShape[0];
240 int64_t n = cTypeShape[1];
241 int64_t k = aTypeShape[1];
242 NVVM::MMALayout aLayout = subgroupMmaComputeOp.getATranspose()
243 ? NVVM::MMALayout::col
244 : NVVM::MMALayout::row;
245 NVVM::MMALayout bLayout = subgroupMmaComputeOp.getBTranspose()
246 ? NVVM::MMALayout::col
247 : NVVM::MMALayout::row;
248 NVVM::MMATypes sourceType = getElementType(aType);
249 NVVM::MMATypes destType = getElementType(cType);
250 if (NVVM::WMMAMmaOp::getIntrinsicID(m, n, k, aLayout, bLayout, sourceType,
251 destType) == 0)
252 return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
253
254 NVVM::MMATypes bElementType = getElementType(
255 cast<gpu::MMAMatrixType>(subgroupMmaComputeOp.getOpB().getType()));
256 if (bElementType != sourceType)
257 return rewriter.notifyMatchFailure(
258 op, "WMMA compute op input matrix element types must match.");
259
260 unpackOp(adaptor.getOpA());
261 unpackOp(adaptor.getOpB());
262 unpackOp(adaptor.getOpC());
263
264 rewriter.replaceOpWithNewOp<NVVM::WMMAMmaOp>(
265 op, adaptor.getOpC().getType(), m, n, k, aLayout, bLayout, sourceType,
266 destType, unpackedOps);
267 return success();
268 }
269};
270
271/// Convert GPU MMA ConstantMatrixOp to a chain of InsertValueOp.
272struct WmmaConstantOpToNVVMLowering
273 : public ConvertOpToLLVMPattern<gpu::SubgroupMmaConstantMatrixOp> {
274 using ConvertOpToLLVMPattern<
275 gpu::SubgroupMmaConstantMatrixOp>::ConvertOpToLLVMPattern;
276
277 LogicalResult
278 matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp subgroupMmaConstantOp,
279 OpAdaptor adaptor,
280 ConversionPatternRewriter &rewriter) const override {
281 if (failed(areAllLLVMTypes(subgroupMmaConstantOp.getOperation(),
282 adaptor.getOperands(), rewriter)))
283 return failure();
284 Location loc = subgroupMmaConstantOp.getLoc();
285 Value cst = adaptor.getOperands()[0];
286 Type type = convertMMAToLLVMType(
287 cast<gpu::MMAMatrixType>(subgroupMmaConstantOp.getType()));
288 // If the element is not a struct, it means it's a scalar f64.
289 auto structType = dyn_cast<LLVM::LLVMStructType>(type);
290 if (!structType) {
291 rewriter.replaceOp(subgroupMmaConstantOp, cst);
292 return success();
293 }
294 // If the element type is a vector create a vector from the operand.
295 if (auto vecType = dyn_cast<VectorType>(structType.getBody()[0])) {
296 Value vecCst = LLVM::PoisonOp::create(rewriter, loc, vecType);
297 for (int64_t vecEl = 0; vecEl < vecType.getNumElements(); vecEl++) {
298 Value idx = LLVM::ConstantOp::create(rewriter, loc,
299 rewriter.getI32Type(), vecEl);
300 vecCst = LLVM::InsertElementOp::create(rewriter, loc, vecType, vecCst,
301 cst, idx);
302 }
303 cst = vecCst;
304 }
305 Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, structType);
306 for (size_t i : llvm::seq(size_t(0), structType.getBody().size())) {
307 matrixStruct =
308 LLVM::InsertValueOp::create(rewriter, loc, matrixStruct, cst, i);
309 }
310 rewriter.replaceOp(subgroupMmaConstantOp, matrixStruct);
311 return success();
312 }
313};
314
315static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
316 Value rhs, bool isMin) {
317 auto floatType = cast<FloatType>(getElementTypeOrSelf(lhs.getType()));
318 Type i1Type = builder.getI1Type();
319 if (auto vecType = dyn_cast<VectorType>(lhs.getType()))
320 i1Type = VectorType::get(vecType.getShape(), i1Type);
321 Value cmp = LLVM::FCmpOp::create(
322 builder, loc, i1Type,
323 isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt, lhs, rhs);
324 Value sel = LLVM::SelectOp::create(builder, loc, cmp, lhs, rhs);
325 Value isNan = LLVM::FCmpOp::create(builder, loc, i1Type,
326 LLVM::FCmpPredicate::uno, lhs, rhs);
327 Value nan = LLVM::ConstantOp::create(
328 builder, loc, lhs.getType(),
329 builder.getFloatAttr(floatType,
330 APFloat::getQNaN(floatType.getFloatSemantics())));
331 return LLVM::SelectOp::create(builder, loc, isNan, nan, sel);
332}
333
334static Value createScalarOp(OpBuilder &builder, Location loc,
335 gpu::MMAElementwiseOp op,
336 ArrayRef<Value> operands) {
337 switch (op) {
338 case gpu::MMAElementwiseOp::ADDF:
339 return LLVM::FAddOp::create(builder, loc, operands[0].getType(), operands);
340 case gpu::MMAElementwiseOp::MULF:
341 return LLVM::FMulOp::create(builder, loc, operands[0].getType(), operands);
342 case gpu::MMAElementwiseOp::DIVF:
343 return LLVM::FDivOp::create(builder, loc, operands[0].getType(), operands);
344 case gpu::MMAElementwiseOp::MAXF:
345 return createMinMaxF(builder, loc, operands[0], operands[1],
346 /*isMin=*/false);
347 case gpu::MMAElementwiseOp::MINF:
348 return createMinMaxF(builder, loc, operands[0], operands[1],
349 /*isMin=*/true);
350 default:
351 llvm_unreachable("unknown op");
352 }
353}
354
355/// Convert GPU MMA elementwise ops to extract + op + insert.
356struct WmmaElementwiseOpToNVVMLowering
357 : public ConvertOpToLLVMPattern<gpu::SubgroupMmaElementwiseOp> {
358 using ConvertOpToLLVMPattern<
359 gpu::SubgroupMmaElementwiseOp>::ConvertOpToLLVMPattern;
360
361 LogicalResult
362 matchAndRewrite(gpu::SubgroupMmaElementwiseOp subgroupMmaElementwiseOp,
363 OpAdaptor adaptor,
364 ConversionPatternRewriter &rewriter) const override {
365 if (failed(areAllLLVMTypes(subgroupMmaElementwiseOp.getOperation(),
366 adaptor.getOperands(), rewriter)))
367 return failure();
368 Location loc = subgroupMmaElementwiseOp.getLoc();
369 size_t numOperands = adaptor.getOperands().size();
370 Type destType = convertMMAToLLVMType(
371 cast<gpu::MMAMatrixType>(subgroupMmaElementwiseOp.getType()));
372
373 // If the element is not a struct, it means it's a scalar f64.
374 LLVM::LLVMStructType structDestTy =
375 dyn_cast<LLVM::LLVMStructType>(destType);
376 if (!structDestTy) {
377 SmallVector<Value> operands;
378 for (auto operand : adaptor.getOperands()) {
379 operands.push_back(operand);
380 }
381 Value element = createScalarOp(
382 rewriter, loc, subgroupMmaElementwiseOp.getOpType(), operands);
383 rewriter.replaceOp(subgroupMmaElementwiseOp, element);
384 return success();
385 }
386 Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, structDestTy);
387 for (size_t i = 0, e = structDestTy.getBody().size(); i < e; ++i) {
388 SmallVector<Value> extractedOperands;
389 for (size_t opIdx = 0; opIdx < numOperands; opIdx++) {
390 extractedOperands.push_back(LLVM::ExtractValueOp::create(
391 rewriter, loc, adaptor.getOperands()[opIdx], i));
392 }
393 Value element =
394 createScalarOp(rewriter, loc, subgroupMmaElementwiseOp.getOpType(),
395 extractedOperands);
396 matrixStruct =
397 LLVM::InsertValueOp::create(rewriter, loc, matrixStruct, element, i);
398 }
399 rewriter.replaceOp(subgroupMmaElementwiseOp, matrixStruct);
400 return success();
401 }
402};
403
404} // namespace
405
406/// Return the LLVMStructureType corresponding to the MMAMatrixType `type`.
408 NVVM::MMAFrag frag = convertOperand(type.getOperand());
409 NVVM::MMATypes eltType = getElementType(type);
410 auto nRow = type.getShape()[0];
411 auto nCol = type.getShape()[1];
412 std::pair<Type, unsigned> typeInfo =
413 NVVM::inferMMAType(eltType, frag, nRow, nCol, type.getContext());
414 // Special handling for f64 a and b fragments
415 Type f64Ty = Float64Type::get(type.getContext());
416 if (typeInfo.first == f64Ty && typeInfo.second == 1) {
417 return f64Ty;
418 }
419 return LLVM::LLVMStructType::getLiteral(
420 type.getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
421}
422
424 const LLVMTypeConverter &converter, RewritePatternSet &patterns,
425 PatternBenefit benefit) {
426 patterns.add<WmmaLoadOpToNVVMLowering, WmmaMmaOpToNVVMLowering,
427 WmmaStoreOpToNVVMLowering, WmmaConstantOpToNVVMLowering,
428 WmmaElementwiseOpToNVVMLowering>(converter, benefit);
429}
return success()
static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter)
lhs
static Type getElementType(Type type)
Determine the element type of type.
FloatAttr getFloatAttr(Type type, double value)
Definition Builders.cpp:258
IntegerType getI1Type()
Definition Builders.cpp:57
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition Pattern.h:227
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:209
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:244
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isF64() const
Definition Types.cpp:41
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition Types.cpp:78
bool isF32() const
Definition Types.cpp:40
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition Types.cpp:90
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
bool isF16() const
Definition Types.cpp:38
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
Definition GPUDialect.h:131
ArrayRef< int64_t > getShape() const
Get shape of the matrix.
Type getElementType() const
Get elementType of a single element.
StringRef getOperand() const
The general form of operation this type supports is given by the equation C += A*B.
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDes...
Definition Pattern.cpp:603
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
std::pair< mlir::Type, unsigned > inferMMAType(mlir::NVVM::MMATypes type, mlir::NVVM::MMAFrag frag, int nRow, int nCol, mlir::MLIRContext *context)
Return the element type and number of elements associated with a wmma matrix of given chracteristics.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
Type convertMMAToLLVMType(gpu::MMAMatrixType type)
Return the LLVMStructureType corresponding to the MMAMatrixType type.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void populateGpuWMMAToNVVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of patterns to convert WMMA ops from GPU dialect to NVVM.