MLIR  20.0.0git
WmmaOpsToNvvm.cpp
Go to the documentation of this file.
1 //===------ WmmaOpsToNVVM.cpp - WMMA LD/ST/Compute to NVVM lowering -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains definitions of patterns to lower GPU Subgroup MMA ops to
10 // NVVM Dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 
19 #include "mlir/IR/TypeUtilities.h"
20 
21 using namespace mlir;
22 
23 namespace {
24 
25 /// Checks if all the operands of the op being lowered are of LLVM Types. The
26 /// types are expected to be converted by the `LLVMTypeConverter` before the op
27 /// is actually lowered. If the type of an operands is not already converted it
28 /// hints a missing typeConversion and failure is returned in that case.
29 static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
30  ConversionPatternRewriter &rewriter) {
31  if (!llvm::all_of(operands, [](Value value) {
32  return LLVM::isCompatibleType(value.getType());
33  })) {
34  return rewriter.notifyMatchFailure(
35  op, "cannot convert if operands aren't of LLVM type.");
36  }
37 
38  return success();
39 }
40 
41 /// Error string to emit when an unimplemented WMMA variant is encountered.
42 static constexpr StringRef kInvalidCaseStr = "Unsupported WMMA variant.";
43 
44 static NVVM::MMAFrag convertOperand(StringRef operandName) {
45  if (operandName == "AOp")
46  return NVVM::MMAFrag::a;
47  if (operandName == "BOp")
48  return NVVM::MMAFrag::b;
49  if (operandName == "COp")
50  return NVVM::MMAFrag::c;
51  llvm_unreachable("Unknown operand name");
52 }
53 
54 static NVVM::MMATypes getElementType(gpu::MMAMatrixType type) {
55  if (type.getElementType().isF16())
56  return NVVM::MMATypes::f16;
57  if (type.getElementType().isF32())
58  return type.getOperand() == "COp" ? NVVM::MMATypes::f32
59  : NVVM::MMATypes::tf32;
60 
61  if (type.getElementType().isSignedInteger(8))
62  return NVVM::MMATypes::s8;
63  if (type.getElementType().isUnsignedInteger(8))
64  return NVVM::MMATypes::u8;
65  // Accumulator type is signless and implies signed.
66  if (type.getElementType().isInteger(32))
67  return NVVM::MMATypes::s32;
68  llvm_unreachable("Unsupported type");
69 }
70 
71 /// This class implements the conversion of GPU MMA loadOp to wmma.load op
72 /// in the NVVM dialect. The conversion not only emits the NVVM op but also
73 /// emits code that is necessary to store the data in the destination memref
74 /// after it has been loaded.
75 struct WmmaLoadOpToNVVMLowering
76  : public ConvertOpToLLVMPattern<gpu::SubgroupMmaLoadMatrixOp> {
78  gpu::SubgroupMmaLoadMatrixOp>::ConvertOpToLLVMPattern;
79 
80  LogicalResult
81  matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp subgroupMmaLoadMatrixOp,
82  OpAdaptor adaptor,
83  ConversionPatternRewriter &rewriter) const override {
84  Operation *op = subgroupMmaLoadMatrixOp.getOperation();
85  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
86  return failure();
87 
88  // Get the shape of the MMAMatrix type being returned. The shape will
89  // choose which intrinsic this op will be lowered to.
90  NVVM::MMALayout layout = subgroupMmaLoadMatrixOp.getTranspose()
91  ? NVVM::MMALayout::col
92  : NVVM::MMALayout::row;
93  gpu::MMAMatrixType retType =
94  cast<gpu::MMAMatrixType>(subgroupMmaLoadMatrixOp.getRes().getType());
95  ArrayRef<int64_t> retTypeShape = retType.getShape();
96  int64_t m = 0;
97  int64_t n = 0;
98  int64_t k = 0;
99  NVVM::MMATypes eltype = getElementType(retType);
100  // NVVM intrinsics require to give mxnxk dimensions, infer the missing
101  // dimension based on the valid intrinsics available.
102  if (retType.getOperand() == "AOp") {
103  m = retTypeShape[0];
104  k = retTypeShape[1];
105  n = NVVM::WMMALoadOp::inferNDimension(m, k, eltype);
106  } else if (retType.getOperand() == "BOp") {
107  k = retTypeShape[0];
108  n = retTypeShape[1];
109  m = NVVM::WMMALoadOp::inferMDimension(k, n, eltype);
110  } else if (retType.getOperand() == "COp") {
111  m = retTypeShape[0];
112  n = retTypeShape[1];
113  k = NVVM::WMMALoadOp::inferKDimension(m, n, eltype);
114  }
115  NVVM::MMAFrag frag = convertOperand(retType.getOperand());
116  // Check that there is an exisiting instruction for the combination we need.
117  if (NVVM::WMMALoadOp::getIntrinsicID(m, n, k, layout, eltype, frag) == 0)
118  return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
119 
120  Type resType = convertMMAToLLVMType(retType);
121  Location loc = op->getLoc();
122 
123  // Create nvvm.mma_load op according to the operand types.
124  Value dataPtr = getStridedElementPtr(
125  loc, cast<MemRefType>(subgroupMmaLoadMatrixOp.getSrcMemref().getType()),
126  adaptor.getSrcMemref(), adaptor.getIndices(), rewriter);
127 
128  Value leadingDim = rewriter.create<LLVM::ConstantOp>(
129  loc, rewriter.getI32Type(),
130  subgroupMmaLoadMatrixOp.getLeadDimensionAttr());
131  rewriter.replaceOpWithNewOp<NVVM::WMMALoadOp>(
132  op, resType, dataPtr, leadingDim, m, n, k, layout, eltype, frag);
133  return success();
134  }
135 };
136 
137 /// This class implements the conversion of GPU MMA storeOp to wmma.store op
138 /// in the NVVM dialect. The conversion not only emits the NVVM op but also
139 /// emits code that is necessary to unpack the data in the source and
140 /// convert the data in the format that is needed by the NVVM op.
141 struct WmmaStoreOpToNVVMLowering
142  : public ConvertOpToLLVMPattern<gpu::SubgroupMmaStoreMatrixOp> {
144  gpu::SubgroupMmaStoreMatrixOp>::ConvertOpToLLVMPattern;
145 
146  LogicalResult
147  matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp subgroupMmaStoreMatrixOp,
148  OpAdaptor adaptor,
149  ConversionPatternRewriter &rewriter) const override {
150  Operation *op = subgroupMmaStoreMatrixOp.getOperation();
151  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
152  return failure();
153 
154  Location loc = op->getLoc();
155 
156  SmallVector<Value, 4> storeOpOperands;
157  // Get the shape of the MMAMatrix type being stored. The shape will
158  // choose which intrinsic this op will be lowered to.
159  gpu::MMAMatrixType srcType =
160  cast<gpu::MMAMatrixType>(subgroupMmaStoreMatrixOp.getSrc().getType());
161  ArrayRef<int64_t> srcTypeShape = srcType.getShape();
162  NVVM::MMALayout layout = subgroupMmaStoreMatrixOp.getTranspose()
163  ? NVVM::MMALayout::col
164  : NVVM::MMALayout::row;
165  NVVM::MMATypes eltype = getElementType(srcType);
166  int64_t m = srcTypeShape[0];
167  int64_t n = srcTypeShape[1];
168  int64_t k = NVVM::WMMAStoreOp::inferKDimension(m, n, eltype);
169  if (NVVM::WMMAStoreOp::getIntrinsicID(m, n, k, layout, eltype) == 0)
170  return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
171 
172  auto matrixType = cast<LLVM::LLVMStructType>(adaptor.getSrc().getType());
173  for (unsigned i = 0, e = matrixType.getBody().size(); i < e; ++i) {
174  Value toUse =
175  rewriter.create<LLVM::ExtractValueOp>(loc, adaptor.getSrc(), i);
176  storeOpOperands.push_back(toUse);
177  }
178 
179  Value dataPtr = getStridedElementPtr(
180  loc,
181  cast<MemRefType>(subgroupMmaStoreMatrixOp.getDstMemref().getType()),
182  adaptor.getDstMemref(), adaptor.getIndices(), rewriter);
183  Value leadingDim = rewriter.create<LLVM::ConstantOp>(
184  loc, rewriter.getI32Type(),
185  subgroupMmaStoreMatrixOp.getLeadDimensionAttr());
186  rewriter.replaceOpWithNewOp<NVVM::WMMAStoreOp>(
187  op, dataPtr, m, n, k, layout, eltype, storeOpOperands, leadingDim);
188  return success();
189  }
190 };
191 
192 /// This class implements the conversion of GPU MMA computeOp to wmma.mma op
193 /// in the NVVM dialect.
194 struct WmmaMmaOpToNVVMLowering
195  : public ConvertOpToLLVMPattern<gpu::SubgroupMmaComputeOp> {
197  gpu::SubgroupMmaComputeOp>::ConvertOpToLLVMPattern;
198 
199  LogicalResult
200  matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
201  OpAdaptor adaptor,
202  ConversionPatternRewriter &rewriter) const override {
203  Operation *op = subgroupMmaComputeOp.getOperation();
204  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
205  return failure();
206 
207  Location loc = op->getLoc();
208 
209  // The wmma.mma intrinsic in llvm requires the operands as individual
210  // values. So individual elements from the memrefs need to be extracted and
211  // then passed on to the intrinsic call. Emit llvm ops to extract individual
212  // values form lowered memrefs.
213  SmallVector<Value> unpackedOps;
214 
215  auto unpackOp = [&](Value operand) {
216  auto structType = cast<LLVM::LLVMStructType>(operand.getType());
217  for (size_t i = 0, e = structType.getBody().size(); i < e; ++i) {
218  Value toUse = rewriter.create<LLVM::ExtractValueOp>(loc, operand, i);
219  unpackedOps.push_back(toUse);
220  }
221  };
222 
223  // Get the shapes of the MMAMatrix type being used. The shapes will
224  // choose which intrinsic this op will be lowered to.
225  gpu::MMAMatrixType aType =
226  cast<gpu::MMAMatrixType>(subgroupMmaComputeOp.getOpA().getType());
227  ArrayRef<int64_t> aTypeShape = aType.getShape();
228  gpu::MMAMatrixType cType =
229  cast<gpu::MMAMatrixType>(subgroupMmaComputeOp.getOpC().getType());
230  ArrayRef<int64_t> cTypeShape = cType.getShape();
231  int64_t m = cTypeShape[0];
232  int64_t n = cTypeShape[1];
233  int64_t k = aTypeShape[1];
234  NVVM::MMALayout aLayout = subgroupMmaComputeOp.getATranspose()
235  ? NVVM::MMALayout::col
236  : NVVM::MMALayout::row;
237  NVVM::MMALayout bLayout = subgroupMmaComputeOp.getBTranspose()
238  ? NVVM::MMALayout::col
239  : NVVM::MMALayout::row;
240  NVVM::MMATypes sourceType = getElementType(aType);
241  NVVM::MMATypes destType = getElementType(cType);
242  if (NVVM::WMMAMmaOp::getIntrinsicID(m, n, k, aLayout, bLayout, sourceType,
243  destType) == 0)
244  return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
245 
246  NVVM::MMATypes bElementType = getElementType(
247  cast<gpu::MMAMatrixType>(subgroupMmaComputeOp.getOpB().getType()));
248  if (bElementType != sourceType)
249  return rewriter.notifyMatchFailure(
250  op, "WMMA compute op input matrix element types must match.");
251 
252  unpackOp(adaptor.getOpA());
253  unpackOp(adaptor.getOpB());
254  unpackOp(adaptor.getOpC());
255 
256  rewriter.replaceOpWithNewOp<NVVM::WMMAMmaOp>(
257  op, adaptor.getOpC().getType(), m, n, k, aLayout, bLayout, sourceType,
258  destType, unpackedOps);
259  return success();
260  }
261 };
262 
263 /// Convert GPU MMA ConstantMatrixOp to a chain of InsertValueOp.
264 struct WmmaConstantOpToNVVMLowering
265  : public ConvertOpToLLVMPattern<gpu::SubgroupMmaConstantMatrixOp> {
267  gpu::SubgroupMmaConstantMatrixOp>::ConvertOpToLLVMPattern;
268 
269  LogicalResult
270  matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp subgroupMmaConstantOp,
271  OpAdaptor adaptor,
272  ConversionPatternRewriter &rewriter) const override {
273  if (failed(areAllLLVMTypes(subgroupMmaConstantOp.getOperation(),
274  adaptor.getOperands(), rewriter)))
275  return failure();
276  Location loc = subgroupMmaConstantOp.getLoc();
277  Value cst = adaptor.getOperands()[0];
279  cast<gpu::MMAMatrixType>(subgroupMmaConstantOp.getType()));
280  // If the element type is a vector create a vector from the operand.
281  if (auto vecType = dyn_cast<VectorType>(type.getBody()[0])) {
282  Value vecCst = rewriter.create<LLVM::UndefOp>(loc, vecType);
283  for (int64_t vecEl = 0; vecEl < vecType.getNumElements(); vecEl++) {
284  Value idx = rewriter.create<LLVM::ConstantOp>(
285  loc, rewriter.getI32Type(), vecEl);
286  vecCst = rewriter.create<LLVM::InsertElementOp>(loc, vecType, vecCst,
287  cst, idx);
288  }
289  cst = vecCst;
290  }
291  Value matrixStruct = rewriter.create<LLVM::UndefOp>(loc, type);
292  for (size_t i : llvm::seq(size_t(0), type.getBody().size())) {
293  matrixStruct =
294  rewriter.create<LLVM::InsertValueOp>(loc, matrixStruct, cst, i);
295  }
296  rewriter.replaceOp(subgroupMmaConstantOp, matrixStruct);
297  return success();
298  }
299 };
300 
301 static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
302  Value rhs, bool isMin) {
303  auto floatType = cast<FloatType>(getElementTypeOrSelf(lhs.getType()));
304  Type i1Type = builder.getI1Type();
305  if (auto vecType = dyn_cast<VectorType>(lhs.getType()))
306  i1Type = VectorType::get(vecType.getShape(), i1Type);
307  Value cmp = builder.create<LLVM::FCmpOp>(
308  loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
309  lhs, rhs);
310  Value sel = builder.create<LLVM::SelectOp>(loc, cmp, lhs, rhs);
311  Value isNan = builder.create<LLVM::FCmpOp>(
312  loc, i1Type, LLVM::FCmpPredicate::uno, lhs, rhs);
313  Value nan = builder.create<LLVM::ConstantOp>(
314  loc, lhs.getType(),
315  builder.getFloatAttr(floatType,
316  APFloat::getQNaN(floatType.getFloatSemantics())));
317  return builder.create<LLVM::SelectOp>(loc, isNan, nan, sel);
318 }
319 
320 static Value createScalarOp(OpBuilder &builder, Location loc,
321  gpu::MMAElementwiseOp op,
322  ArrayRef<Value> operands) {
323  switch (op) {
324  case gpu::MMAElementwiseOp::ADDF:
325  return builder.create<LLVM::FAddOp>(loc, operands[0].getType(), operands);
326  case gpu::MMAElementwiseOp::MULF:
327  return builder.create<LLVM::FMulOp>(loc, operands[0].getType(), operands);
328  case gpu::MMAElementwiseOp::DIVF:
329  return builder.create<LLVM::FDivOp>(loc, operands[0].getType(), operands);
330  case gpu::MMAElementwiseOp::MAXF:
331  return createMinMaxF(builder, loc, operands[0], operands[1],
332  /*isMin=*/false);
333  case gpu::MMAElementwiseOp::MINF:
334  return createMinMaxF(builder, loc, operands[0], operands[1],
335  /*isMin=*/true);
336  default:
337  llvm_unreachable("unknown op");
338  }
339 }
340 
341 /// Convert GPU MMA elementwise ops to extract + op + insert.
342 struct WmmaElementwiseOpToNVVMLowering
343  : public ConvertOpToLLVMPattern<gpu::SubgroupMmaElementwiseOp> {
345  gpu::SubgroupMmaElementwiseOp>::ConvertOpToLLVMPattern;
346 
347  LogicalResult
348  matchAndRewrite(gpu::SubgroupMmaElementwiseOp subgroupMmaElementwiseOp,
349  OpAdaptor adaptor,
350  ConversionPatternRewriter &rewriter) const override {
351  if (failed(areAllLLVMTypes(subgroupMmaElementwiseOp.getOperation(),
352  adaptor.getOperands(), rewriter)))
353  return failure();
354  Location loc = subgroupMmaElementwiseOp.getLoc();
355  size_t numOperands = adaptor.getOperands().size();
357  cast<gpu::MMAMatrixType>(subgroupMmaElementwiseOp.getType()));
358  Value matrixStruct = rewriter.create<LLVM::UndefOp>(loc, destType);
359  for (size_t i = 0, e = destType.getBody().size(); i < e; ++i) {
360  SmallVector<Value> extractedOperands;
361  for (size_t opIdx = 0; opIdx < numOperands; opIdx++) {
362  extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
363  loc, adaptor.getOperands()[opIdx], i));
364  }
365  Value element =
366  createScalarOp(rewriter, loc, subgroupMmaElementwiseOp.getOpType(),
367  extractedOperands);
368  matrixStruct =
369  rewriter.create<LLVM::InsertValueOp>(loc, matrixStruct, element, i);
370  }
371  rewriter.replaceOp(subgroupMmaElementwiseOp, matrixStruct);
372  return success();
373  }
374 };
375 
376 } // namespace
377 
378 /// Return the LLVMStructureType corresponding to the MMAMatrixType `type`.
380  NVVM::MMAFrag frag = convertOperand(type.getOperand());
381  NVVM::MMATypes eltType = getElementType(type);
382  auto nRow = type.getShape()[0];
383  auto nCol = type.getShape()[1];
384  std::pair<Type, unsigned> typeInfo =
385  NVVM::inferMMAType(eltType, frag, nRow, nCol, type.getContext());
387  type.getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
388 }
389 
391  LLVMTypeConverter &converter, RewritePatternSet &patterns) {
392  patterns.add<WmmaLoadOpToNVVMLowering, WmmaMmaOpToNVVMLowering,
393  WmmaStoreOpToNVVMLowering, WmmaConstantOpToNVVMLowering,
394  WmmaElementwiseOpToNVVMLowering>(converter);
395 }
static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:215
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:277
IntegerType getI32Type()
Definition: Builders.cpp:99
IntegerType getI1Type()
Definition: Builders.cpp:89
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:143
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
LLVM dialect structure type representing a collection of different-typed elements manipulated togethe...
Definition: LLVMTypes.h:108
ArrayRef< Type > getBody() const
Returns the list of element types contained in a non-opaque struct.
Definition: LLVMTypes.cpp:489
static LLVMStructType getLiteral(MLIRContext *context, ArrayRef< Type > types, bool isPacked=false)
Gets or creates a literal struct with the given body in the provided context.
Definition: LLVMTypes.cpp:452
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
This class helps build Operations.
Definition: Builders.h:213
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:480
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition: Types.cpp:83
bool isF32() const
Definition: Types.cpp:55
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition: Types.cpp:95
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:62
bool isF16() const
Definition: Types.cpp:53
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
Definition: GPUDialect.h:131
ArrayRef< int64_t > getShape() const
Get shape of the matrix.
Definition: GPUDialect.cpp:136
Type getElementType() const
Get elementType of a single element.
Definition: GPUDialect.cpp:140
StringRef getOperand() const
The general form of operation this type supports is given by the equation C += A*B.
Definition: GPUDialect.cpp:142
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:856
std::pair< mlir::Type, unsigned > inferMMAType(mlir::NVVM::MMATypes type, mlir::NVVM::MMAFrag frag, int nRow, int nCol, mlir::MLIRContext *context)
Return the element type and number of elements associated with a wmma matrix of given characteristics.
Include the generated interface declarations.
LLVM::LLVMStructType convertMMAToLLVMType(gpu::MMAMatrixType type)
Return the LLVMStructureType corresponding to the MMAMatrixType type.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateGpuWMMAToNVVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to convert WMMA ops from GPU dialect to NVVM.