MLIR 22.0.0git
WmmaOpsToNvvm.cpp
Go to the documentation of this file.
1//===------ WmmaOpsToNVVM.cpp - WMMA LD/ST/Compute to NVVM lowering -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains definitions of patterns to lower GPU Subgroup MMA ops to
10// NVVM Dialect.
11//
12//===----------------------------------------------------------------------===//
13
20#include "mlir/IR/Types.h"
21
22using namespace mlir;
23
24namespace {
25
26/// Checks if all the operands of the op being lowered are of LLVM Types. The
27/// types are expected to be converted by the `LLVMTypeConverter` before the op
28/// is actually lowered. If the type of an operands is not already converted it
29/// hints a missing typeConversion and failure is returned in that case.
30static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
31 ConversionPatternRewriter &rewriter) {
32 if (!llvm::all_of(operands, [](Value value) {
33 return LLVM::isCompatibleType(value.getType());
34 })) {
35 return rewriter.notifyMatchFailure(
36 op, "cannot convert if operands aren't of LLVM type.");
37 }
38
39 return success();
40}
41
/// Error string to emit when an unimplemented WMMA variant (i.e. a shape /
/// element-type / layout combination with no matching NVVM intrinsic) is
/// encountered.
static constexpr StringRef kInvalidCaseStr = "Unsupported WMMA variant.";
44
45static NVVM::MMAFrag convertOperand(StringRef operandName) {
46 if (operandName == "AOp")
47 return NVVM::MMAFrag::a;
48 if (operandName == "BOp")
49 return NVVM::MMAFrag::b;
50 if (operandName == "COp")
51 return NVVM::MMAFrag::c;
52 llvm_unreachable("Unknown operand name");
53}
54
55static NVVM::MMATypes getElementType(gpu::MMAMatrixType type) {
56 if (type.getElementType().isF16())
57 return NVVM::MMATypes::f16;
58 if (type.getElementType().isF32())
59 return type.getOperand() == "COp" ? NVVM::MMATypes::f32
60 : NVVM::MMATypes::tf32;
61 if (type.getElementType().isF64())
62 return NVVM::MMATypes::f64;
63 if (type.getElementType().isSignedInteger(8))
64 return NVVM::MMATypes::s8;
66 return NVVM::MMATypes::u8;
67 // Accumulator type is signless and implies signed.
68 if (type.getElementType().isInteger(32))
69 return NVVM::MMATypes::s32;
70 llvm_unreachable("Unsupported type");
71}
72
/// This class implements the conversion of GPU MMA loadOp to wmma.load op
/// in the NVVM dialect. The conversion not only emits the NVVM op but also
/// emits code that is necessary to store the data in the destination memref
/// after it has been loaded.
struct WmmaLoadOpToNVVMLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaLoadMatrixOp> {
  using ConvertOpToLLVMPattern<
      gpu::SubgroupMmaLoadMatrixOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp subgroupMmaLoadMatrixOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Operation *op = subgroupMmaLoadMatrixOp.getOperation();
    // Operands must already be converted to LLVM-compatible types.
    if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
      return failure();

    // Get the shape of the MMAMatrix type being returned. The shape will
    // choose which intrinsic this op will be lowered to.
    NVVM::MMALayout layout = subgroupMmaLoadMatrixOp.getTranspose()
                                 ? NVVM::MMALayout::col
                                 : NVVM::MMALayout::row;
    gpu::MMAMatrixType retType =
        cast<gpu::MMAMatrixType>(subgroupMmaLoadMatrixOp.getRes().getType());
    ArrayRef<int64_t> retTypeShape = retType.getShape();
    int64_t m = 0;
    int64_t n = 0;
    int64_t k = 0;
    NVVM::MMATypes eltype = getElementType(retType);
    // NVVM intrinsics require to give mxnxk dimensions, infer the missing
    // dimension based on the valid intrinsics available.
    if (retType.getOperand() == "AOp") {
      // A fragment is m x k; infer n.
      m = retTypeShape[0];
      k = retTypeShape[1];
      n = NVVM::WMMALoadOp::inferNDimension(m, k, eltype);
    } else if (retType.getOperand() == "BOp") {
      // B fragment is k x n; infer m.
      k = retTypeShape[0];
      n = retTypeShape[1];
      m = NVVM::WMMALoadOp::inferMDimension(k, n, eltype);
    } else if (retType.getOperand() == "COp") {
      // C (accumulator) fragment is m x n; infer k.
      m = retTypeShape[0];
      n = retTypeShape[1];
      k = NVVM::WMMALoadOp::inferKDimension(m, n, eltype);
    }
    NVVM::MMAFrag frag = convertOperand(retType.getOperand());
    // Check that there is an existing instruction for the combination we need.
    if (NVVM::WMMALoadOp::getIntrinsicID(m, n, k, layout, eltype, frag) == 0)
      return rewriter.notifyMatchFailure(op, kInvalidCaseStr);

    Type resType = convertMMAToLLVMType(retType);
    Location loc = op->getLoc();

    // Create nvvm.mma_load op according to the operand types.
    // Compute the address of the first element to load from the (converted)
    // source memref descriptor.
    Value dataPtr = getStridedElementPtr(
        rewriter, loc,
        cast<MemRefType>(subgroupMmaLoadMatrixOp.getSrcMemref().getType()),
        adaptor.getSrcMemref(), adaptor.getIndices());

    // The intrinsic takes the leading dimension as an i32 operand.
    Value leadingDim = LLVM::ConstantOp::create(
        rewriter, loc, rewriter.getI32Type(),
        subgroupMmaLoadMatrixOp.getLeadDimensionAttr());
    rewriter.replaceOpWithNewOp<NVVM::WMMALoadOp>(
        op, resType, dataPtr, leadingDim, m, n, k, layout, eltype, frag);
    return success();
  }
};
139
140/// This class implements the conversion of GPU MMA storeOp to wmma.store op
141/// in the NVVM dialect. The conversion not only emits the NVVM op but also
142/// emits code that is necessary to unpack the data in the source and
143/// convert the data in the format that is needed by the NVVM op.
144struct WmmaStoreOpToNVVMLowering
145 : public ConvertOpToLLVMPattern<gpu::SubgroupMmaStoreMatrixOp> {
146 using ConvertOpToLLVMPattern<
147 gpu::SubgroupMmaStoreMatrixOp>::ConvertOpToLLVMPattern;
148
149 LogicalResult
150 matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp subgroupMmaStoreMatrixOp,
151 OpAdaptor adaptor,
152 ConversionPatternRewriter &rewriter) const override {
153 Operation *op = subgroupMmaStoreMatrixOp.getOperation();
154 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
155 return failure();
156
157 Location loc = op->getLoc();
158
159 SmallVector<Value, 4> storeOpOperands;
160 // Get the shape of the MMAMatrix type being stored. The shape will
161 // choose which intrinsic this op will be lowered to.
162 gpu::MMAMatrixType srcType =
163 cast<gpu::MMAMatrixType>(subgroupMmaStoreMatrixOp.getSrc().getType());
164 ArrayRef<int64_t> srcTypeShape = srcType.getShape();
165 NVVM::MMALayout layout = subgroupMmaStoreMatrixOp.getTranspose()
166 ? NVVM::MMALayout::col
167 : NVVM::MMALayout::row;
168 NVVM::MMATypes eltype = getElementType(srcType);
169 int64_t m = srcTypeShape[0];
170 int64_t n = srcTypeShape[1];
171 int64_t k = NVVM::WMMAStoreOp::inferKDimension(m, n, eltype);
172 if (NVVM::WMMAStoreOp::getIntrinsicID(m, n, k, layout, eltype) == 0)
173 return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
174
175 auto matrixType = cast<LLVM::LLVMStructType>(adaptor.getSrc().getType());
176 for (unsigned i = 0, e = matrixType.getBody().size(); i < e; ++i) {
177 Value toUse =
178 LLVM::ExtractValueOp::create(rewriter, loc, adaptor.getSrc(), i);
179 storeOpOperands.push_back(toUse);
180 }
181
182 Value dataPtr = getStridedElementPtr(
183 rewriter, loc,
184 cast<MemRefType>(subgroupMmaStoreMatrixOp.getDstMemref().getType()),
185 adaptor.getDstMemref(), adaptor.getIndices());
186 Value leadingDim = LLVM::ConstantOp::create(
187 rewriter, loc, rewriter.getI32Type(),
188 subgroupMmaStoreMatrixOp.getLeadDimensionAttr());
189 rewriter.replaceOpWithNewOp<NVVM::WMMAStoreOp>(
190 op, dataPtr, m, n, k, layout, eltype, storeOpOperands, leadingDim);
191 return success();
192 }
193};
194
/// This class implements the conversion of GPU MMA computeOp to wmma.mma op
/// in the NVVM dialect.
struct WmmaMmaOpToNVVMLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaComputeOp> {
  using ConvertOpToLLVMPattern<
      gpu::SubgroupMmaComputeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Operation *op = subgroupMmaComputeOp.getOperation();
    // Operands must already be converted to LLVM-compatible types.
    if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
      return failure();

    Location loc = op->getLoc();

    // The wmma.mma intrinsic in llvm requires the operands as individual
    // values. So individual elements from the memrefs need to be extracted and
    // then passed on to the intrinsic call. Emit llvm ops to extract individual
    // values from lowered memrefs.
    SmallVector<Value> unpackedOps;
    auto unpackOp = [&](Value operand) {
      // f64 a and b fragments are not structs but scalars.
      if (!isa<LLVM::LLVMStructType>(operand.getType())) {
        unpackedOps.push_back(operand);
        return;
      }
      // every other type is lowered to an LLVM struct, extract the values.
      auto structType = cast<LLVM::LLVMStructType>(operand.getType());
      for (size_t i = 0, e = structType.getBody().size(); i < e; ++i) {
        Value toUse = LLVM::ExtractValueOp::create(rewriter, loc, operand, i);
        unpackedOps.push_back(toUse);
      }
    };

    // Get the shapes of the MMAMatrix type being used. The shapes will
    // choose which intrinsic this op will be lowered to.
    gpu::MMAMatrixType aType =
        cast<gpu::MMAMatrixType>(subgroupMmaComputeOp.getOpA().getType());
    ArrayRef<int64_t> aTypeShape = aType.getShape();
    gpu::MMAMatrixType cType =
        cast<gpu::MMAMatrixType>(subgroupMmaComputeOp.getOpC().getType());
    ArrayRef<int64_t> cTypeShape = cType.getShape();
    // C is m x n and A is m x k, so k comes from A's second dimension.
    int64_t m = cTypeShape[0];
    int64_t n = cTypeShape[1];
    int64_t k = aTypeShape[1];
    NVVM::MMALayout aLayout = subgroupMmaComputeOp.getATranspose()
                                  ? NVVM::MMALayout::col
                                  : NVVM::MMALayout::row;
    NVVM::MMALayout bLayout = subgroupMmaComputeOp.getBTranspose()
                                  ? NVVM::MMALayout::col
                                  : NVVM::MMALayout::row;
    NVVM::MMATypes sourceType = getElementType(aType);
    NVVM::MMATypes destType = getElementType(cType);
    // Reject combinations that have no matching NVVM intrinsic.
    if (NVVM::WMMAMmaOp::getIntrinsicID(m, n, k, aLayout, bLayout, sourceType,
                                        destType) == 0)
      return rewriter.notifyMatchFailure(op, kInvalidCaseStr);

    // A and B must carry the same element type.
    NVVM::MMATypes bElementType = getElementType(
        cast<gpu::MMAMatrixType>(subgroupMmaComputeOp.getOpB().getType()));
    if (bElementType != sourceType)
      return rewriter.notifyMatchFailure(
          op, "WMMA compute op input matrix element types must match.");

    // Unpack A, B and C in that order; the intrinsic expects their elements
    // concatenated.
    unpackOp(adaptor.getOpA());
    unpackOp(adaptor.getOpB());
    unpackOp(adaptor.getOpC());

    rewriter.replaceOpWithNewOp<NVVM::WMMAMmaOp>(
        op, adaptor.getOpC().getType(), m, n, k, aLayout, bLayout, sourceType,
        destType, unpackedOps);
    return success();
  }
};
270
271/// Convert GPU MMA ConstantMatrixOp to a chain of InsertValueOp.
272struct WmmaConstantOpToNVVMLowering
273 : public ConvertOpToLLVMPattern<gpu::SubgroupMmaConstantMatrixOp> {
274 using ConvertOpToLLVMPattern<
275 gpu::SubgroupMmaConstantMatrixOp>::ConvertOpToLLVMPattern;
276
277 LogicalResult
278 matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp subgroupMmaConstantOp,
279 OpAdaptor adaptor,
280 ConversionPatternRewriter &rewriter) const override {
281 if (failed(areAllLLVMTypes(subgroupMmaConstantOp.getOperation(),
282 adaptor.getOperands(), rewriter)))
283 return failure();
284 Location loc = subgroupMmaConstantOp.getLoc();
285 Value cst = adaptor.getOperands()[0];
286 Type type = convertMMAToLLVMType(
287 cast<gpu::MMAMatrixType>(subgroupMmaConstantOp.getType()));
288 // If the element is not a struct, it means it's a scalar f64.
289 auto structType = dyn_cast<LLVM::LLVMStructType>(type);
290 if (!structType) {
291 rewriter.replaceOp(subgroupMmaConstantOp, cst);
292 return success();
293 }
294 // If the element type is a vector create a vector from the operand.
295 if (auto vecType = dyn_cast<VectorType>(structType.getBody()[0])) {
296 Value vecCst = LLVM::PoisonOp::create(rewriter, loc, vecType);
297 for (int64_t vecEl = 0; vecEl < vecType.getNumElements(); vecEl++) {
298 Value idx = LLVM::ConstantOp::create(rewriter, loc,
299 rewriter.getI32Type(), vecEl);
300 vecCst = LLVM::InsertElementOp::create(rewriter, loc, vecType, vecCst,
301 cst, idx);
302 }
303 cst = vecCst;
304 }
305 Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, structType);
306 for (size_t i : llvm::seq(size_t(0), structType.getBody().size())) {
307 matrixStruct =
308 LLVM::InsertValueOp::create(rewriter, loc, matrixStruct, cst, i);
309 }
310 rewriter.replaceOp(subgroupMmaConstantOp, matrixStruct);
311 return success();
312 }
313};
314
315static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
316 Value rhs, bool isMin) {
317 auto floatType = cast<FloatType>(getElementTypeOrSelf(lhs.getType()));
318 Type i1Type = builder.getI1Type();
319 if (auto vecType = dyn_cast<VectorType>(lhs.getType()))
320 i1Type = VectorType::get(vecType.getShape(), i1Type);
321 Value cmp = LLVM::FCmpOp::create(
322 builder, loc, i1Type,
323 isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt, lhs, rhs);
324 Value sel = LLVM::SelectOp::create(builder, loc, cmp, lhs, rhs);
325 Value isNan = LLVM::FCmpOp::create(builder, loc, i1Type,
326 LLVM::FCmpPredicate::uno, lhs, rhs);
327 Value nan = LLVM::ConstantOp::create(
328 builder, loc, lhs.getType(),
329 builder.getFloatAttr(floatType,
330 APFloat::getQNaN(floatType.getFloatSemantics())));
331 return LLVM::SelectOp::create(builder, loc, isNan, nan, sel);
332}
333
334static Value createScalarOp(OpBuilder &builder, Location loc,
335 gpu::MMAElementwiseOp op,
336 ArrayRef<Value> operands) {
337 switch (op) {
338 case gpu::MMAElementwiseOp::ADDF:
339 return LLVM::FAddOp::create(builder, loc, operands[0].getType(), operands);
340 case gpu::MMAElementwiseOp::MULF:
341 return LLVM::FMulOp::create(builder, loc, operands[0].getType(), operands);
342 case gpu::MMAElementwiseOp::DIVF:
343 return LLVM::FDivOp::create(builder, loc, operands[0].getType(), operands);
344 case gpu::MMAElementwiseOp::MAXF:
345 return createMinMaxF(builder, loc, operands[0], operands[1],
346 /*isMin=*/false);
347 case gpu::MMAElementwiseOp::MINF:
348 return createMinMaxF(builder, loc, operands[0], operands[1],
349 /*isMin=*/true);
350 default:
351 llvm_unreachable("unknown op");
352 }
353}
354
355/// Convert GPU MMA elementwise ops to extract + op + insert.
356struct WmmaElementwiseOpToNVVMLowering
357 : public ConvertOpToLLVMPattern<gpu::SubgroupMmaElementwiseOp> {
358 using ConvertOpToLLVMPattern<
359 gpu::SubgroupMmaElementwiseOp>::ConvertOpToLLVMPattern;
360
361 LogicalResult
362 matchAndRewrite(gpu::SubgroupMmaElementwiseOp subgroupMmaElementwiseOp,
363 OpAdaptor adaptor,
364 ConversionPatternRewriter &rewriter) const override {
365 if (failed(areAllLLVMTypes(subgroupMmaElementwiseOp.getOperation(),
366 adaptor.getOperands(), rewriter)))
367 return failure();
368 Location loc = subgroupMmaElementwiseOp.getLoc();
369 size_t numOperands = adaptor.getOperands().size();
370 Type destType = convertMMAToLLVMType(
371 cast<gpu::MMAMatrixType>(subgroupMmaElementwiseOp.getType()));
372
373 // If the element is not a struct, it means it's a scalar f64.
374 LLVM::LLVMStructType structDestTy =
375 dyn_cast<LLVM::LLVMStructType>(destType);
376 if (!structDestTy) {
377 SmallVector<Value> operands;
378 for (auto operand : adaptor.getOperands()) {
379 operands.push_back(operand);
380 }
381 Value element = createScalarOp(
382 rewriter, loc, subgroupMmaElementwiseOp.getOpType(), operands);
383 rewriter.replaceOp(subgroupMmaElementwiseOp, element);
384 return success();
385 }
386 Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, structDestTy);
387 for (size_t i = 0, e = structDestTy.getBody().size(); i < e; ++i) {
388 SmallVector<Value> extractedOperands;
389 for (size_t opIdx = 0; opIdx < numOperands; opIdx++) {
390 extractedOperands.push_back(LLVM::ExtractValueOp::create(
391 rewriter, loc, adaptor.getOperands()[opIdx], i));
392 }
393 Value element =
394 createScalarOp(rewriter, loc, subgroupMmaElementwiseOp.getOpType(),
395 extractedOperands);
396 matrixStruct =
397 LLVM::InsertValueOp::create(rewriter, loc, matrixStruct, element, i);
398 }
399 rewriter.replaceOp(subgroupMmaElementwiseOp, matrixStruct);
400 return success();
401 }
402};
403
404} // namespace
405
406/// Return the LLVMStructureType corresponding to the MMAMatrixType `type`.
408 NVVM::MMAFrag frag = convertOperand(type.getOperand());
409 NVVM::MMATypes eltType = getElementType(type);
410 auto nRow = type.getShape()[0];
411 auto nCol = type.getShape()[1];
412 std::pair<Type, unsigned> typeInfo =
413 NVVM::inferMMAType(eltType, frag, nRow, nCol, type.getContext());
414 // Special handling for f64 a and b fragments
415 Type f64Ty = Float64Type::get(type.getContext());
416 if (typeInfo.first == f64Ty && typeInfo.second == 1) {
417 return f64Ty;
418 }
419 return LLVM::LLVMStructType::getLiteral(
420 type.getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
421}
422
425 PatternBenefit benefit) {
426 patterns.add<WmmaLoadOpToNVVMLowering, WmmaMmaOpToNVVMLowering,
427 WmmaStoreOpToNVVMLowering, WmmaConstantOpToNVVMLowering,
428 WmmaElementwiseOpToNVVMLowering>(converter, benefit);
429}
return success()
static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter)
lhs
static Type getElementType(Type type)
Determine the element type of type.
FloatAttr getFloatAttr(Type type, double value)
Definition Builders.cpp:254
IntegerType getI1Type()
Definition Builders.cpp:53
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition Pattern.h:207
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:207
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isF64() const
Definition Types.cpp:41
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition Types.cpp:76
bool isF32() const
Definition Types.cpp:40
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition Types.cpp:88
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:56
bool isF16() const
Definition Types.cpp:38
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
Definition GPUDialect.h:131
ArrayRef< int64_t > getShape() const
Get shape of the matrix.
Type getElementType() const
Get elementType of a single element.
StringRef getOperand() const
The general form of operation this type supports is given by the equation C += A*B.
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDes...
Definition Pattern.cpp:471
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
std::pair< mlir::Type, unsigned > inferMMAType(mlir::NVVM::MMATypes type, mlir::NVVM::MMAFrag frag, int nRow, int nCol, mlir::MLIRContext *context)
Return the element type and number of elements associated with a wmma matrix of given chracteristics.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
Type convertMMAToLLVMType(gpu::MMAMatrixType type)
Return the LLVMStructureType corresponding to the MMAMatrixType type.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void populateGpuWMMAToNVVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of patterns to convert WMMA ops from GPU dialect to NVVM.