MLIR  21.0.0git
WmmaOpsToNvvm.cpp
Go to the documentation of this file.
1 //===------ WmmaOpsToNVVM.cpp - WMMA LD/ST/Compute to NVVM lowering -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains definitions of patterns to lower GPU Subgroup MMA ops to
10 // NVVM Dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 
19 #include "mlir/IR/TypeUtilities.h"
20 
21 using namespace mlir;
22 
23 namespace {
24 
25 /// Checks if all the operands of the op being lowered are of LLVM Types. The
26 /// types are expected to be converted by the `LLVMTypeConverter` before the op
27 /// is actually lowered. If the type of an operands is not already converted it
28 /// hints a missing typeConversion and failure is returned in that case.
29 static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
30  ConversionPatternRewriter &rewriter) {
31  if (!llvm::all_of(operands, [](Value value) {
32  return LLVM::isCompatibleType(value.getType());
33  })) {
34  return rewriter.notifyMatchFailure(
35  op, "cannot convert if operands aren't of LLVM type.");
36  }
37 
38  return success();
39 }
40 
41 /// Error string to emit when an unimplemented WMMA variant is encountered.
42 static constexpr StringRef kInvalidCaseStr = "Unsupported WMMA variant.";
43 
44 static NVVM::MMAFrag convertOperand(StringRef operandName) {
45  if (operandName == "AOp")
46  return NVVM::MMAFrag::a;
47  if (operandName == "BOp")
48  return NVVM::MMAFrag::b;
49  if (operandName == "COp")
50  return NVVM::MMAFrag::c;
51  llvm_unreachable("Unknown operand name");
52 }
53 
54 static NVVM::MMATypes getElementType(gpu::MMAMatrixType type) {
55  if (type.getElementType().isF16())
56  return NVVM::MMATypes::f16;
57  if (type.getElementType().isF32())
58  return type.getOperand() == "COp" ? NVVM::MMATypes::f32
59  : NVVM::MMATypes::tf32;
60 
61  if (type.getElementType().isSignedInteger(8))
62  return NVVM::MMATypes::s8;
63  if (type.getElementType().isUnsignedInteger(8))
64  return NVVM::MMATypes::u8;
65  // Accumulator type is signless and implies signed.
66  if (type.getElementType().isInteger(32))
67  return NVVM::MMATypes::s32;
68  llvm_unreachable("Unsupported type");
69 }
70 
71 /// This class implements the conversion of GPU MMA loadOp to wmma.load op
72 /// in the NVVM dialect. The conversion not only emits the NVVM op but also
73 /// emits code that is necessary to store the data in the destination memref
74 /// after it has been loaded.
75 struct WmmaLoadOpToNVVMLowering
76  : public ConvertOpToLLVMPattern<gpu::SubgroupMmaLoadMatrixOp> {
78  gpu::SubgroupMmaLoadMatrixOp>::ConvertOpToLLVMPattern;
79 
80  LogicalResult
81  matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp subgroupMmaLoadMatrixOp,
82  OpAdaptor adaptor,
83  ConversionPatternRewriter &rewriter) const override {
84  Operation *op = subgroupMmaLoadMatrixOp.getOperation();
85  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
86  return failure();
87 
88  // Get the shape of the MMAMatrix type being returned. The shape will
89  // choose which intrinsic this op will be lowered to.
90  NVVM::MMALayout layout = subgroupMmaLoadMatrixOp.getTranspose()
91  ? NVVM::MMALayout::col
92  : NVVM::MMALayout::row;
93  gpu::MMAMatrixType retType =
94  cast<gpu::MMAMatrixType>(subgroupMmaLoadMatrixOp.getRes().getType());
95  ArrayRef<int64_t> retTypeShape = retType.getShape();
96  int64_t m = 0;
97  int64_t n = 0;
98  int64_t k = 0;
99  NVVM::MMATypes eltype = getElementType(retType);
100  // NVVM intrinsics require to give mxnxk dimensions, infer the missing
101  // dimension based on the valid intrinsics available.
102  if (retType.getOperand() == "AOp") {
103  m = retTypeShape[0];
104  k = retTypeShape[1];
105  n = NVVM::WMMALoadOp::inferNDimension(m, k, eltype);
106  } else if (retType.getOperand() == "BOp") {
107  k = retTypeShape[0];
108  n = retTypeShape[1];
109  m = NVVM::WMMALoadOp::inferMDimension(k, n, eltype);
110  } else if (retType.getOperand() == "COp") {
111  m = retTypeShape[0];
112  n = retTypeShape[1];
113  k = NVVM::WMMALoadOp::inferKDimension(m, n, eltype);
114  }
115  NVVM::MMAFrag frag = convertOperand(retType.getOperand());
116  // Check that there is an exisiting instruction for the combination we need.
117  if (NVVM::WMMALoadOp::getIntrinsicID(m, n, k, layout, eltype, frag) == 0)
118  return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
119 
120  Type resType = convertMMAToLLVMType(retType);
121  Location loc = op->getLoc();
122 
123  // Create nvvm.mma_load op according to the operand types.
124  Value dataPtr = getStridedElementPtr(
125  rewriter, loc,
126  cast<MemRefType>(subgroupMmaLoadMatrixOp.getSrcMemref().getType()),
127  adaptor.getSrcMemref(), adaptor.getIndices());
128 
129  Value leadingDim = rewriter.create<LLVM::ConstantOp>(
130  loc, rewriter.getI32Type(),
131  subgroupMmaLoadMatrixOp.getLeadDimensionAttr());
132  rewriter.replaceOpWithNewOp<NVVM::WMMALoadOp>(
133  op, resType, dataPtr, leadingDim, m, n, k, layout, eltype, frag);
134  return success();
135  }
136 };
137 
138 /// This class implements the conversion of GPU MMA storeOp to wmma.store op
139 /// in the NVVM dialect. The conversion not only emits the NVVM op but also
140 /// emits code that is necessary to unpack the data in the source and
141 /// convert the data in the format that is needed by the NVVM op.
142 struct WmmaStoreOpToNVVMLowering
143  : public ConvertOpToLLVMPattern<gpu::SubgroupMmaStoreMatrixOp> {
145  gpu::SubgroupMmaStoreMatrixOp>::ConvertOpToLLVMPattern;
146 
147  LogicalResult
148  matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp subgroupMmaStoreMatrixOp,
149  OpAdaptor adaptor,
150  ConversionPatternRewriter &rewriter) const override {
151  Operation *op = subgroupMmaStoreMatrixOp.getOperation();
152  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
153  return failure();
154 
155  Location loc = op->getLoc();
156 
157  SmallVector<Value, 4> storeOpOperands;
158  // Get the shape of the MMAMatrix type being stored. The shape will
159  // choose which intrinsic this op will be lowered to.
160  gpu::MMAMatrixType srcType =
161  cast<gpu::MMAMatrixType>(subgroupMmaStoreMatrixOp.getSrc().getType());
162  ArrayRef<int64_t> srcTypeShape = srcType.getShape();
163  NVVM::MMALayout layout = subgroupMmaStoreMatrixOp.getTranspose()
164  ? NVVM::MMALayout::col
165  : NVVM::MMALayout::row;
166  NVVM::MMATypes eltype = getElementType(srcType);
167  int64_t m = srcTypeShape[0];
168  int64_t n = srcTypeShape[1];
169  int64_t k = NVVM::WMMAStoreOp::inferKDimension(m, n, eltype);
170  if (NVVM::WMMAStoreOp::getIntrinsicID(m, n, k, layout, eltype) == 0)
171  return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
172 
173  auto matrixType = cast<LLVM::LLVMStructType>(adaptor.getSrc().getType());
174  for (unsigned i = 0, e = matrixType.getBody().size(); i < e; ++i) {
175  Value toUse =
176  rewriter.create<LLVM::ExtractValueOp>(loc, adaptor.getSrc(), i);
177  storeOpOperands.push_back(toUse);
178  }
179 
180  Value dataPtr = getStridedElementPtr(
181  rewriter, loc,
182  cast<MemRefType>(subgroupMmaStoreMatrixOp.getDstMemref().getType()),
183  adaptor.getDstMemref(), adaptor.getIndices());
184  Value leadingDim = rewriter.create<LLVM::ConstantOp>(
185  loc, rewriter.getI32Type(),
186  subgroupMmaStoreMatrixOp.getLeadDimensionAttr());
187  rewriter.replaceOpWithNewOp<NVVM::WMMAStoreOp>(
188  op, dataPtr, m, n, k, layout, eltype, storeOpOperands, leadingDim);
189  return success();
190  }
191 };
192 
193 /// This class implements the conversion of GPU MMA computeOp to wmma.mma op
194 /// in the NVVM dialect.
195 struct WmmaMmaOpToNVVMLowering
196  : public ConvertOpToLLVMPattern<gpu::SubgroupMmaComputeOp> {
198  gpu::SubgroupMmaComputeOp>::ConvertOpToLLVMPattern;
199 
200  LogicalResult
201  matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
202  OpAdaptor adaptor,
203  ConversionPatternRewriter &rewriter) const override {
204  Operation *op = subgroupMmaComputeOp.getOperation();
205  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
206  return failure();
207 
208  Location loc = op->getLoc();
209 
210  // The wmma.mma intrinsic in llvm requires the operands as individual
211  // values. So individual elements from the memrefs need to be extracted and
212  // then passed on to the intrinsic call. Emit llvm ops to extract individual
213  // values form lowered memrefs.
214  SmallVector<Value> unpackedOps;
215 
216  auto unpackOp = [&](Value operand) {
217  auto structType = cast<LLVM::LLVMStructType>(operand.getType());
218  for (size_t i = 0, e = structType.getBody().size(); i < e; ++i) {
219  Value toUse = rewriter.create<LLVM::ExtractValueOp>(loc, operand, i);
220  unpackedOps.push_back(toUse);
221  }
222  };
223 
224  // Get the shapes of the MMAMatrix type being used. The shapes will
225  // choose which intrinsic this op will be lowered to.
226  gpu::MMAMatrixType aType =
227  cast<gpu::MMAMatrixType>(subgroupMmaComputeOp.getOpA().getType());
228  ArrayRef<int64_t> aTypeShape = aType.getShape();
229  gpu::MMAMatrixType cType =
230  cast<gpu::MMAMatrixType>(subgroupMmaComputeOp.getOpC().getType());
231  ArrayRef<int64_t> cTypeShape = cType.getShape();
232  int64_t m = cTypeShape[0];
233  int64_t n = cTypeShape[1];
234  int64_t k = aTypeShape[1];
235  NVVM::MMALayout aLayout = subgroupMmaComputeOp.getATranspose()
236  ? NVVM::MMALayout::col
237  : NVVM::MMALayout::row;
238  NVVM::MMALayout bLayout = subgroupMmaComputeOp.getBTranspose()
239  ? NVVM::MMALayout::col
240  : NVVM::MMALayout::row;
241  NVVM::MMATypes sourceType = getElementType(aType);
242  NVVM::MMATypes destType = getElementType(cType);
243  if (NVVM::WMMAMmaOp::getIntrinsicID(m, n, k, aLayout, bLayout, sourceType,
244  destType) == 0)
245  return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
246 
247  NVVM::MMATypes bElementType = getElementType(
248  cast<gpu::MMAMatrixType>(subgroupMmaComputeOp.getOpB().getType()));
249  if (bElementType != sourceType)
250  return rewriter.notifyMatchFailure(
251  op, "WMMA compute op input matrix element types must match.");
252 
253  unpackOp(adaptor.getOpA());
254  unpackOp(adaptor.getOpB());
255  unpackOp(adaptor.getOpC());
256 
257  rewriter.replaceOpWithNewOp<NVVM::WMMAMmaOp>(
258  op, adaptor.getOpC().getType(), m, n, k, aLayout, bLayout, sourceType,
259  destType, unpackedOps);
260  return success();
261  }
262 };
263 
264 /// Convert GPU MMA ConstantMatrixOp to a chain of InsertValueOp.
265 struct WmmaConstantOpToNVVMLowering
266  : public ConvertOpToLLVMPattern<gpu::SubgroupMmaConstantMatrixOp> {
268  gpu::SubgroupMmaConstantMatrixOp>::ConvertOpToLLVMPattern;
269 
270  LogicalResult
271  matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp subgroupMmaConstantOp,
272  OpAdaptor adaptor,
273  ConversionPatternRewriter &rewriter) const override {
274  if (failed(areAllLLVMTypes(subgroupMmaConstantOp.getOperation(),
275  adaptor.getOperands(), rewriter)))
276  return failure();
277  Location loc = subgroupMmaConstantOp.getLoc();
278  Value cst = adaptor.getOperands()[0];
279  LLVM::LLVMStructType type = convertMMAToLLVMType(
280  cast<gpu::MMAMatrixType>(subgroupMmaConstantOp.getType()));
281  // If the element type is a vector create a vector from the operand.
282  if (auto vecType = dyn_cast<VectorType>(type.getBody()[0])) {
283  Value vecCst = rewriter.create<LLVM::PoisonOp>(loc, vecType);
284  for (int64_t vecEl = 0; vecEl < vecType.getNumElements(); vecEl++) {
285  Value idx = rewriter.create<LLVM::ConstantOp>(
286  loc, rewriter.getI32Type(), vecEl);
287  vecCst = rewriter.create<LLVM::InsertElementOp>(loc, vecType, vecCst,
288  cst, idx);
289  }
290  cst = vecCst;
291  }
292  Value matrixStruct = rewriter.create<LLVM::PoisonOp>(loc, type);
293  for (size_t i : llvm::seq(size_t(0), type.getBody().size())) {
294  matrixStruct =
295  rewriter.create<LLVM::InsertValueOp>(loc, matrixStruct, cst, i);
296  }
297  rewriter.replaceOp(subgroupMmaConstantOp, matrixStruct);
298  return success();
299  }
300 };
301 
302 static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
303  Value rhs, bool isMin) {
304  auto floatType = cast<FloatType>(getElementTypeOrSelf(lhs.getType()));
305  Type i1Type = builder.getI1Type();
306  if (auto vecType = dyn_cast<VectorType>(lhs.getType()))
307  i1Type = VectorType::get(vecType.getShape(), i1Type);
308  Value cmp = builder.create<LLVM::FCmpOp>(
309  loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
310  lhs, rhs);
311  Value sel = builder.create<LLVM::SelectOp>(loc, cmp, lhs, rhs);
312  Value isNan = builder.create<LLVM::FCmpOp>(
313  loc, i1Type, LLVM::FCmpPredicate::uno, lhs, rhs);
314  Value nan = builder.create<LLVM::ConstantOp>(
315  loc, lhs.getType(),
316  builder.getFloatAttr(floatType,
317  APFloat::getQNaN(floatType.getFloatSemantics())));
318  return builder.create<LLVM::SelectOp>(loc, isNan, nan, sel);
319 }
320 
321 static Value createScalarOp(OpBuilder &builder, Location loc,
322  gpu::MMAElementwiseOp op,
323  ArrayRef<Value> operands) {
324  switch (op) {
325  case gpu::MMAElementwiseOp::ADDF:
326  return builder.create<LLVM::FAddOp>(loc, operands[0].getType(), operands);
327  case gpu::MMAElementwiseOp::MULF:
328  return builder.create<LLVM::FMulOp>(loc, operands[0].getType(), operands);
329  case gpu::MMAElementwiseOp::DIVF:
330  return builder.create<LLVM::FDivOp>(loc, operands[0].getType(), operands);
331  case gpu::MMAElementwiseOp::MAXF:
332  return createMinMaxF(builder, loc, operands[0], operands[1],
333  /*isMin=*/false);
334  case gpu::MMAElementwiseOp::MINF:
335  return createMinMaxF(builder, loc, operands[0], operands[1],
336  /*isMin=*/true);
337  default:
338  llvm_unreachable("unknown op");
339  }
340 }
341 
342 /// Convert GPU MMA elementwise ops to extract + op + insert.
343 struct WmmaElementwiseOpToNVVMLowering
344  : public ConvertOpToLLVMPattern<gpu::SubgroupMmaElementwiseOp> {
346  gpu::SubgroupMmaElementwiseOp>::ConvertOpToLLVMPattern;
347 
348  LogicalResult
349  matchAndRewrite(gpu::SubgroupMmaElementwiseOp subgroupMmaElementwiseOp,
350  OpAdaptor adaptor,
351  ConversionPatternRewriter &rewriter) const override {
352  if (failed(areAllLLVMTypes(subgroupMmaElementwiseOp.getOperation(),
353  adaptor.getOperands(), rewriter)))
354  return failure();
355  Location loc = subgroupMmaElementwiseOp.getLoc();
356  size_t numOperands = adaptor.getOperands().size();
357  LLVM::LLVMStructType destType = convertMMAToLLVMType(
358  cast<gpu::MMAMatrixType>(subgroupMmaElementwiseOp.getType()));
359  Value matrixStruct = rewriter.create<LLVM::PoisonOp>(loc, destType);
360  for (size_t i = 0, e = destType.getBody().size(); i < e; ++i) {
361  SmallVector<Value> extractedOperands;
362  for (size_t opIdx = 0; opIdx < numOperands; opIdx++) {
363  extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
364  loc, adaptor.getOperands()[opIdx], i));
365  }
366  Value element =
367  createScalarOp(rewriter, loc, subgroupMmaElementwiseOp.getOpType(),
368  extractedOperands);
369  matrixStruct =
370  rewriter.create<LLVM::InsertValueOp>(loc, matrixStruct, element, i);
371  }
372  rewriter.replaceOp(subgroupMmaElementwiseOp, matrixStruct);
373  return success();
374  }
375 };
376 
377 } // namespace
378 
379 /// Return the LLVMStructureType corresponding to the MMAMatrixType `type`.
380 LLVM::LLVMStructType mlir::convertMMAToLLVMType(gpu::MMAMatrixType type) {
381  NVVM::MMAFrag frag = convertOperand(type.getOperand());
382  NVVM::MMATypes eltType = getElementType(type);
383  auto nRow = type.getShape()[0];
384  auto nCol = type.getShape()[1];
385  std::pair<Type, unsigned> typeInfo =
386  NVVM::inferMMAType(eltType, frag, nRow, nCol, type.getContext());
387  return LLVM::LLVMStructType::getLiteral(
388  type.getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
389 }
390 
392  const LLVMTypeConverter &converter, RewritePatternSet &patterns,
393  PatternBenefit benefit) {
394  patterns.add<WmmaLoadOpToNVVMLowering, WmmaMmaOpToNVVMLowering,
395  WmmaStoreOpToNVVMLowering, WmmaConstantOpToNVVMLowering,
396  WmmaElementwiseOpToNVVMLowering>(converter, benefit);
397 }
static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:188
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:250
IntegerType getI32Type()
Definition: Builders.cpp:63
IntegerType getI1Type()
Definition: Builders.cpp:53
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:191
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
This class helps build Operations.
Definition: Builders.h:204
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:682
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition: Types.cpp:76
bool isF32() const
Definition: Types.cpp:40
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition: Types.cpp:88
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:56
bool isF16() const
Definition: Types.cpp:38
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
Definition: GPUDialect.h:131
ArrayRef< int64_t > getShape() const
Get shape of the matrix.
Definition: GPUDialect.cpp:140
Type getElementType() const
Get elementType of a single element.
Definition: GPUDialect.cpp:144
StringRef getOperand() const
The general form of operation this type supports is given by the equation C += A*B.
Definition: GPUDialect.cpp:146
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDes...
Definition: Pattern.cpp:487
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:796
std::pair< mlir::Type, unsigned > inferMMAType(mlir::NVVM::MMATypes type, mlir::NVVM::MMAFrag frag, int nRow, int nCol, mlir::MLIRContext *context)
Return the element type and number of elements associated with a WMMA matrix of the given characteristics.
Include the generated interface declarations.
LLVM::LLVMStructType convertMMAToLLVMType(gpu::MMAMatrixType type)
Return the LLVMStructureType corresponding to the MMAMatrixType type.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateGpuWMMAToNVVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of patterns to convert WMMA ops from GPU dialect to NVVM.