MLIR  20.0.0git
WmmaOpsToSPIRV.cpp
Go to the documentation of this file.
1 //===------ WmmaOpsToSPIRV.cpp - WMMA LD/ST/Compute to SPIRV lowering -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains definitions of patterns to lower GPU Subgroup MMA ops to
10 // SPIRV Cooperative Matrix ops.
11 //
12 //===----------------------------------------------------------------------===//
13 
25 #include "mlir/IR/BuiltinTypes.h"
26 #include "mlir/IR/TypeUtilities.h"
27 #include "mlir/IR/ValueRange.h"
28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/ADT/StringSwitch.h"
30 
31 #include <cassert>
32 
33 namespace mlir {
34 //===----------------------------------------------------------------------===//
35 // Patterns and helpers.
36 //===----------------------------------------------------------------------===//
37 
38 /// Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op
39 /// when the elementwise op directly supports with cooperative matrix type.
40 /// Returns false if cannot.
41 ///
42 /// See SPV_KHR_cooperative_matrix for supported elementwise ops.
44  gpu::SubgroupMmaElementwiseOp op, Type coopType,
45  ValueRange operands) {
46  assert((isa<spirv::CooperativeMatrixType>(coopType)));
47 
48  switch (op.getOpType()) {
49  case gpu::MMAElementwiseOp::ADDF:
50  builder.replaceOpWithNewOp<spirv::FAddOp>(op, coopType, operands);
51  return true;
53  builder.replaceOpWithNewOp<spirv::IAddOp>(op, coopType, operands);
54  return true;
55  case gpu::MMAElementwiseOp::SUBF:
56  builder.replaceOpWithNewOp<spirv::FSubOp>(op, coopType, operands);
57  return true;
59  builder.replaceOpWithNewOp<spirv::ISubOp>(op, coopType, operands);
60  return true;
61  case gpu::MMAElementwiseOp::DIVF:
62  builder.replaceOpWithNewOp<spirv::FDivOp>(op, coopType, operands);
63  return true;
64  case gpu::MMAElementwiseOp::DIVS:
65  builder.replaceOpWithNewOp<spirv::SDivOp>(op, coopType, operands);
66  return true;
67  case gpu::MMAElementwiseOp::DIVU:
68  builder.replaceOpWithNewOp<spirv::UDivOp>(op, coopType, operands);
69  return true;
70  case gpu::MMAElementwiseOp::NEGATEF:
71  builder.replaceOpWithNewOp<spirv::FNegateOp>(op, coopType, operands);
72  return true;
73  case gpu::MMAElementwiseOp::NEGATES:
74  builder.replaceOpWithNewOp<spirv::SNegateOp>(op, coopType, operands);
75  return true;
76  case gpu::MMAElementwiseOp::EXTF:
77  builder.replaceOpWithNewOp<spirv::FConvertOp>(op, coopType, operands);
78  return true;
79  default:
80  break;
81  }
82  return false;
83 }
84 
86  assert(!operands.empty());
87  if (!llvm::all_equal(
88  llvm::map_range(operands, [](Value v) { return v.getType(); })))
89  return false;
90 
91  return isa<spirv::CooperativeMatrixType>(operands.front().getType());
92 }
93 
94 namespace {
95 /// Converts GPU MMA ConstantMatrixOp to constant SPIR-V KHR/NV cooperative
96 /// matrix ops.
97 struct WmmaConstantOpToSPIRVLowering final
98  : OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
100 
101  LogicalResult
102  matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp op, OpAdaptor adaptor,
103  ConversionPatternRewriter &rewriter) const override {
104  assert(adaptor.getOperands().size() == 1);
105  Value cst = adaptor.getOperands().front();
106  auto coopType = getTypeConverter()->convertType(op.getType());
107  if (!coopType)
108  return rewriter.notifyMatchFailure(op, "type conversion failed");
109 
110  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, coopType, cst);
111  return success();
112  }
113 };
114 
115 /// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
116 /// the default case.
117 struct WmmaElementwiseOpToSPIRVDefaultLowering final
118  : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
120 
121  LogicalResult
122  matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
123  ConversionPatternRewriter &rewriter) const override {
124  // All operands should be of cooperative matrix types.
125  if (!allOperandsHaveSameCoopMatrixType(adaptor.getOperands())) {
126  return rewriter.notifyMatchFailure(op,
127  "not all operands are coop matrices");
128  }
129 
130  auto coopType = getTypeConverter()->convertType(op.getType());
131  if (!coopType)
132  return rewriter.notifyMatchFailure(op, "type conversion failed");
133 
134  return success(
135  createElementwiseOp(rewriter, op, coopType, adaptor.getOperands()));
136  }
137 };
138 
139 /// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
140 /// matrix times scalar case.
141 struct WmmaElementwiseOpToSPIRVScalarMulLowering final
142  : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
144 
145  LogicalResult
146  matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
147  ConversionPatternRewriter &rewriter) const override {
148  if (adaptor.getOperands().size() != 2)
149  return failure();
150 
151  // All operands should be of cooperative matrix types.
152  if (!allOperandsHaveSameCoopMatrixType(adaptor.getOperands())) {
153  return rewriter.notifyMatchFailure(op,
154  "not all operands are coop matrices");
155  }
156 
157  if (op.getOpType() != gpu::MMAElementwiseOp::MULF)
158  return failure();
159 
160  // Use the original operands to check whether one of the operands is a splat
161  // scalar value.
162  Value lhs = op.getOperands().front();
163  Value rhs = op.getOperands().back();
164  Value splat = nullptr;
165  Value matrix = nullptr;
166  if (lhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
167  splat = adaptor.getOperands().front();
168  matrix = adaptor.getOperands().back();
169  } else if (rhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
170  matrix = adaptor.getOperands().front();
171  splat = adaptor.getOperands().back();
172  }
173  if (!splat || !matrix)
174  return rewriter.notifyMatchFailure(op, "no splat operand");
175 
176  // Constant MMA matrix ops are converted to `spirv.CompositeConstruct` ops.
177  Value scalar;
178  auto cc = splat.getDefiningOp<spirv::CompositeConstructOp>();
179  if (!cc) {
180  return rewriter.notifyMatchFailure(op,
181  "splat is not a composite construct");
182  }
183 
184  assert(cc.getConstituents().size() == 1);
185  scalar = cc.getConstituents().front();
186 
187  auto coopType = getTypeConverter()->convertType(op.getType());
188  if (!coopType)
189  return rewriter.notifyMatchFailure(op, "type conversion failed");
190  rewriter.replaceOpWithNewOp<spirv::MatrixTimesScalarOp>(
191  op, coopType, ValueRange{matrix, scalar});
192  return success();
193  }
194 };
195 } // namespace
196 
197 //===----------------------------------------------------------------------===//
198 // SPV_KHR_cooperative_matrix
199 //===----------------------------------------------------------------------===//
200 
201 namespace khr {
202 namespace {
203 
204 /// Converts the GPU MMA loadOp to KHRCooperativeMatrixLoad op in the SPIRV
205 /// dialect.
206 struct WmmaLoadOpToSPIRVLowering final
207  : OpConversionPattern<gpu::SubgroupMmaLoadMatrixOp> {
209 
210  LogicalResult
211  matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp op, OpAdaptor adaptor,
212  ConversionPatternRewriter &rewriter) const override {
213  const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
214  Location loc = op->getLoc();
215 
216  auto retType = cast<gpu::MMAMatrixType>(op.getRes().getType());
217  MemRefType memrefType = op.getSrcMemref().getType();
218  Value bufferPtr =
219  spirv::getElementPtr(typeConverter, memrefType, adaptor.getSrcMemref(),
220  adaptor.getIndices(), loc, rewriter);
221 
222  auto coopType =
223  typeConverter.convertType<spirv::CooperativeMatrixType>(retType);
224  if (!coopType)
225  return rewriter.notifyMatchFailure(op, "type conversion failed");
226 
227  int64_t stride = op.getLeadDimension().getSExtValue();
228  IntegerType i32Type = rewriter.getI32Type();
229  auto strideValue = rewriter.create<spirv::ConstantOp>(
230  loc, i32Type, IntegerAttr::get(i32Type, stride));
231 
232  bool isColMajor = op.getTranspose().value_or(false);
233  auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
234  : spirv::CooperativeMatrixLayoutKHR::RowMajor;
235 
236  rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixLoadOp>(
237  op, coopType, bufferPtr, strideValue, layout);
238  return success();
239  }
240 };
241 
242 /// Converts the GPU MMA StoreOp to KHRCooperativeMatrixStore op in the SPIRV
243 /// dialect.
244 struct WmmaStoreOpToSPIRVLowering final
245  : OpConversionPattern<gpu::SubgroupMmaStoreMatrixOp> {
247 
248  LogicalResult
249  matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp op, OpAdaptor adaptor,
250  ConversionPatternRewriter &rewriter) const override {
251  const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
252  Location loc = op->getLoc();
253 
254  auto memrefType = cast<MemRefType>(op.getDstMemref().getType());
255  Value bufferPtr =
256  spirv::getElementPtr(typeConverter, memrefType, adaptor.getDstMemref(),
257  adaptor.getIndices(), loc, rewriter);
258 
259  int64_t stride = op.getLeadDimension().getSExtValue();
260  IntegerType i32Type = rewriter.getI32Type();
261  auto strideValue = rewriter.create<spirv::ConstantOp>(
262  loc, i32Type, IntegerAttr::get(i32Type, stride));
263 
264  bool isColMajor = op.getTranspose().value_or(false);
265  auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
266  : spirv::CooperativeMatrixLayoutKHR::RowMajor;
267 
268  rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixStoreOp>(
269  op, bufferPtr, adaptor.getSrc(), strideValue, layout);
270  return success();
271  }
272 };
273 
274 /// Converts GPU MMA Compute to KHRCooperativeMatrixMulAdd op in the SPIRV
275 /// dialect.
276 struct WmmaMmaOpToSPIRVLowering final
277  : OpConversionPattern<gpu::SubgroupMmaComputeOp> {
279 
280  LogicalResult
281  matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
282  OpAdaptor adaptor,
283  ConversionPatternRewriter &rewriter) const override {
284  rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixMulAddOp>(
285  subgroupMmaComputeOp, adaptor.getOpA(), adaptor.getOpB(),
286  adaptor.getOpC());
287  return success();
288  }
289 };
290 
291 } // namespace
292 } // namespace khr
293 } // namespace mlir
294 
296  SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
297  using namespace mlir;
298  MLIRContext *context = patterns.getContext();
299  patterns.add<khr::WmmaLoadOpToSPIRVLowering, khr::WmmaMmaOpToSPIRVLowering,
300  khr::WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
301  WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
302  // Give the following patterns higher benefit to prevail over the default one.
303  patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
304  /*benefit=*/2);
305 }
306 
308  mlir::SPIRVTypeConverter &typeConverter) {
309  typeConverter.addConversion([](gpu::MMAMatrixType type) {
310  ArrayRef<int64_t> retTypeShape = type.getShape();
311  Type elementType = type.getElementType();
312  auto use =
314  .Case("AOp", spirv::CooperativeMatrixUseKHR::MatrixA)
315  .Case("BOp", spirv::CooperativeMatrixUseKHR::MatrixB)
316  .Default(spirv::CooperativeMatrixUseKHR::MatrixAcc);
317 
318  return spirv::CooperativeMatrixType::get(elementType, retTypeShape[0],
319  retTypeShape[1],
320  spirv::Scope::Subgroup, use);
321  });
322 }
#define SUBI(lhs, rhs)
Definition: LoopEmitter.cpp:37
#define ADDI(lhs, rhs)
Definition: LoopEmitter.cpp:35
IntegerType getI32Type()
Definition: Builders.cpp:95
This class implements a pattern rewriter for use with ConversionPatterns.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:476
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
MLIRContext * getContext() const
Definition: PatternMatch.h:823
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
Type conversion from builtin types to SPIR-V types for shader interface.
void addConversion(FnT &&callback)
Register a conversion function.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
Definition: GPUDialect.h:131
ArrayRef< int64_t > getShape() const
Get shape of the matrix.
Definition: GPUDialect.cpp:136
Type getElementType() const
Get elementType of a single element.
Definition: GPUDialect.cpp:140
StringRef getOperand() const
The general form of operation this type supports is given by the equation C += A*B.
Definition: GPUDialect.cpp:142
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Include the generated interface declarations.
void populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV, using the KHR Cooperative Ma...
static bool createElementwiseOp(ConversionPatternRewriter &builder, gpu::SubgroupMmaElementwiseOp op, Type coopType, ValueRange operands)
Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op when the elementwise op dire...
bool allOperandsHaveSameCoopMatrixType(ValueRange operands)
void populateMMAToSPIRVCoopMatrixTypeConversion(SPIRVTypeConverter &typeConverter)
Adds MMAMatrixType conversions to SPIR-V cooperative matrix KHR type conversion to the type converter...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...