MLIR  21.0.0git
WmmaOpsToSPIRV.cpp
Go to the documentation of this file.
1 //===------ WmmaOpsToSPIRV.cpp - WMMA LD/ST/Compute to SPIRV lowering -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains definitions of patterns to lower GPU Subgroup MMA ops to
10 // SPIRV Cooperative Matrix ops.
11 //
12 //===----------------------------------------------------------------------===//
13 
#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/ValueRange.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"

#include <cassert>
33 namespace mlir {
34 //===----------------------------------------------------------------------===//
35 // Patterns and helpers.
36 //===----------------------------------------------------------------------===//
37 
38 /// Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op
39 /// when the elementwise op directly supports with cooperative matrix type.
40 /// Returns false if cannot.
41 ///
42 /// See SPV_KHR_cooperative_matrix for supported elementwise ops.
44  gpu::SubgroupMmaElementwiseOp op, Type coopType,
45  ValueRange operands) {
46  assert((isa<spirv::CooperativeMatrixType>(coopType)));
47 
48  switch (op.getOpType()) {
49  case gpu::MMAElementwiseOp::ADDF:
50  builder.replaceOpWithNewOp<spirv::FAddOp>(op, coopType, operands);
51  return true;
53  builder.replaceOpWithNewOp<spirv::IAddOp>(op, coopType, operands);
54  return true;
55  case gpu::MMAElementwiseOp::SUBF:
56  builder.replaceOpWithNewOp<spirv::FSubOp>(op, coopType, operands);
57  return true;
59  builder.replaceOpWithNewOp<spirv::ISubOp>(op, coopType, operands);
60  return true;
61  case gpu::MMAElementwiseOp::DIVF:
62  builder.replaceOpWithNewOp<spirv::FDivOp>(op, coopType, operands);
63  return true;
64  case gpu::MMAElementwiseOp::DIVS:
65  builder.replaceOpWithNewOp<spirv::SDivOp>(op, coopType, operands);
66  return true;
67  case gpu::MMAElementwiseOp::DIVU:
68  builder.replaceOpWithNewOp<spirv::UDivOp>(op, coopType, operands);
69  return true;
70  case gpu::MMAElementwiseOp::NEGATEF:
71  builder.replaceOpWithNewOp<spirv::FNegateOp>(op, coopType, operands);
72  return true;
73  case gpu::MMAElementwiseOp::NEGATES:
74  builder.replaceOpWithNewOp<spirv::SNegateOp>(op, coopType, operands);
75  return true;
76  case gpu::MMAElementwiseOp::EXTF:
77  builder.replaceOpWithNewOp<spirv::FConvertOp>(op, coopType, operands);
78  return true;
79  default:
80  break;
81  }
82  return false;
83 }
84 
86  assert(!operands.empty());
87  if (!llvm::all_equal(
88  llvm::map_range(operands, [](Value v) { return v.getType(); })))
89  return false;
90 
91  return isa<spirv::CooperativeMatrixType>(operands.front().getType());
92 }
93 
94 namespace {
95 /// Converts GPU MMA ConstantMatrixOp to constant SPIR-V KHR/NV cooperative
96 /// matrix ops.
97 struct WmmaConstantOpToSPIRVLowering final
98  : OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
100 
101  LogicalResult
102  matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp op, OpAdaptor adaptor,
103  ConversionPatternRewriter &rewriter) const override {
104  Value cst = llvm::getSingleElement(adaptor.getOperands());
105  auto coopType = getTypeConverter()->convertType(op.getType());
106  if (!coopType)
107  return rewriter.notifyMatchFailure(op, "type conversion failed");
108 
109  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, coopType, cst);
110  return success();
111  }
112 };
113 
114 /// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
115 /// the default case.
116 struct WmmaElementwiseOpToSPIRVDefaultLowering final
117  : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
119 
120  LogicalResult
121  matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
122  ConversionPatternRewriter &rewriter) const override {
123  // All operands should be of cooperative matrix types.
124  if (!allOperandsHaveSameCoopMatrixType(adaptor.getOperands())) {
125  return rewriter.notifyMatchFailure(op,
126  "not all operands are coop matrices");
127  }
128 
129  auto coopType = getTypeConverter()->convertType(op.getType());
130  if (!coopType)
131  return rewriter.notifyMatchFailure(op, "type conversion failed");
132 
133  return success(
134  createElementwiseOp(rewriter, op, coopType, adaptor.getOperands()));
135  }
136 };
137 
138 /// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
139 /// matrix times scalar case.
140 struct WmmaElementwiseOpToSPIRVScalarMulLowering final
141  : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
143 
144  LogicalResult
145  matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
146  ConversionPatternRewriter &rewriter) const override {
147  if (adaptor.getOperands().size() != 2)
148  return failure();
149 
150  // All operands should be of cooperative matrix types.
151  if (!allOperandsHaveSameCoopMatrixType(adaptor.getOperands())) {
152  return rewriter.notifyMatchFailure(op,
153  "not all operands are coop matrices");
154  }
155 
156  if (op.getOpType() != gpu::MMAElementwiseOp::MULF)
157  return failure();
158 
159  // Use the original operands to check whether one of the operands is a splat
160  // scalar value.
161  Value lhs = op.getOperands().front();
162  Value rhs = op.getOperands().back();
163  Value splat = nullptr;
164  Value matrix = nullptr;
165  if (lhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
166  splat = adaptor.getOperands().front();
167  matrix = adaptor.getOperands().back();
168  } else if (rhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
169  matrix = adaptor.getOperands().front();
170  splat = adaptor.getOperands().back();
171  }
172  if (!splat || !matrix)
173  return rewriter.notifyMatchFailure(op, "no splat operand");
174 
175  // Constant MMA matrix ops are converted to `spirv.CompositeConstruct` ops.
176  Value scalar;
177  auto cc = splat.getDefiningOp<spirv::CompositeConstructOp>();
178  if (!cc) {
179  return rewriter.notifyMatchFailure(op,
180  "splat is not a composite construct");
181  }
182 
183  scalar = llvm::getSingleElement(cc.getConstituents());
184 
185  auto coopType = getTypeConverter()->convertType(op.getType());
186  if (!coopType)
187  return rewriter.notifyMatchFailure(op, "type conversion failed");
188  rewriter.replaceOpWithNewOp<spirv::MatrixTimesScalarOp>(
189  op, coopType, ValueRange{matrix, scalar});
190  return success();
191  }
192 };
193 } // namespace
194 
195 //===----------------------------------------------------------------------===//
196 // SPV_KHR_cooperative_matrix
197 //===----------------------------------------------------------------------===//
198 
199 namespace khr {
200 namespace {
201 
202 /// Converts the GPU MMA loadOp to KHRCooperativeMatrixLoad op in the SPIRV
203 /// dialect.
204 struct WmmaLoadOpToSPIRVLowering final
205  : OpConversionPattern<gpu::SubgroupMmaLoadMatrixOp> {
207 
208  LogicalResult
209  matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp op, OpAdaptor adaptor,
210  ConversionPatternRewriter &rewriter) const override {
211  const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
212  Location loc = op->getLoc();
213 
214  auto retType = cast<gpu::MMAMatrixType>(op.getRes().getType());
215  MemRefType memrefType = op.getSrcMemref().getType();
216  Value bufferPtr =
217  spirv::getElementPtr(typeConverter, memrefType, adaptor.getSrcMemref(),
218  adaptor.getIndices(), loc, rewriter);
219 
220  auto coopType =
221  typeConverter.convertType<spirv::CooperativeMatrixType>(retType);
222  if (!coopType)
223  return rewriter.notifyMatchFailure(op, "type conversion failed");
224 
225  int64_t stride = op.getLeadDimension().getSExtValue();
226  IntegerType i32Type = rewriter.getI32Type();
227  auto strideValue = rewriter.create<spirv::ConstantOp>(
228  loc, i32Type, IntegerAttr::get(i32Type, stride));
229 
230  bool isColMajor = op.getTranspose().value_or(false);
231  auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
232  : spirv::CooperativeMatrixLayoutKHR::RowMajor;
233 
234  rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixLoadOp>(
235  op, coopType, bufferPtr, strideValue, layout);
236  return success();
237  }
238 };
239 
240 /// Converts the GPU MMA StoreOp to KHRCooperativeMatrixStore op in the SPIRV
241 /// dialect.
242 struct WmmaStoreOpToSPIRVLowering final
243  : OpConversionPattern<gpu::SubgroupMmaStoreMatrixOp> {
245 
246  LogicalResult
247  matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp op, OpAdaptor adaptor,
248  ConversionPatternRewriter &rewriter) const override {
249  const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
250  Location loc = op->getLoc();
251 
252  auto memrefType = cast<MemRefType>(op.getDstMemref().getType());
253  Value bufferPtr =
254  spirv::getElementPtr(typeConverter, memrefType, adaptor.getDstMemref(),
255  adaptor.getIndices(), loc, rewriter);
256 
257  int64_t stride = op.getLeadDimension().getSExtValue();
258  IntegerType i32Type = rewriter.getI32Type();
259  auto strideValue = rewriter.create<spirv::ConstantOp>(
260  loc, i32Type, IntegerAttr::get(i32Type, stride));
261 
262  bool isColMajor = op.getTranspose().value_or(false);
263  auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
264  : spirv::CooperativeMatrixLayoutKHR::RowMajor;
265 
266  rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixStoreOp>(
267  op, bufferPtr, adaptor.getSrc(), strideValue, layout);
268  return success();
269  }
270 };
271 
272 /// Converts GPU MMA Compute to KHRCooperativeMatrixMulAdd op in the SPIRV
273 /// dialect.
274 struct WmmaMmaOpToSPIRVLowering final
275  : OpConversionPattern<gpu::SubgroupMmaComputeOp> {
277 
278  LogicalResult
279  matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
280  OpAdaptor adaptor,
281  ConversionPatternRewriter &rewriter) const override {
282  rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixMulAddOp>(
283  subgroupMmaComputeOp, adaptor.getOpA(), adaptor.getOpB(),
284  adaptor.getOpC());
285  return success();
286  }
287 };
288 
289 } // namespace
290 } // namespace khr
291 } // namespace mlir
292 
294  const SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
295  using namespace mlir;
296  MLIRContext *context = patterns.getContext();
297  patterns.add<khr::WmmaLoadOpToSPIRVLowering, khr::WmmaMmaOpToSPIRVLowering,
298  khr::WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
299  WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
300  // Give the following patterns higher benefit to prevail over the default one.
301  patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
302  /*benefit=*/2);
303 }
304 
306  mlir::SPIRVTypeConverter &typeConverter) {
307  typeConverter.addConversion([](gpu::MMAMatrixType type) {
308  ArrayRef<int64_t> retTypeShape = type.getShape();
309  Type elementType = type.getElementType();
310  auto use =
312  .Case("AOp", spirv::CooperativeMatrixUseKHR::MatrixA)
313  .Case("BOp", spirv::CooperativeMatrixUseKHR::MatrixB)
314  .Default(spirv::CooperativeMatrixUseKHR::MatrixAcc);
315 
316  return spirv::CooperativeMatrixType::get(elementType, retTypeShape[0],
317  retTypeShape[1],
318  spirv::Scope::Subgroup, use);
319  });
320 }
#define SUBI(lhs, rhs)
Definition: LoopEmitter.cpp:37
#define ADDI(lhs, rhs)
Definition: LoopEmitter.cpp:35
IntegerType getI32Type()
Definition: Builders.cpp:63
This class implements a pattern rewriter for use with ConversionPatterns.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:682
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
Type conversion from builtin types to SPIR-V types for shader interface.
void addConversion(FnT &&callback)
Register a conversion function.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
Definition: GPUDialect.h:131
ArrayRef< int64_t > getShape() const
Get shape of the matrix.
Definition: GPUDialect.cpp:140
Type getElementType() const
Get elementType of a single element.
Definition: GPUDialect.cpp:144
StringRef getOperand() const
The general form of operation this type supports is given by the equation C += A*B.
Definition: GPUDialect.cpp:146
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Include the generated interface declarations.
static bool createElementwiseOp(ConversionPatternRewriter &builder, gpu::SubgroupMmaElementwiseOp op, Type coopType, ValueRange operands)
Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op when the elementwise op dire...
void populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV, using the KHR Cooperative Ma...
const FrozenRewritePatternSet & patterns
bool allOperandsHaveSameCoopMatrixType(ValueRange operands)
void populateMMAToSPIRVCoopMatrixTypeConversion(SPIRVTypeConverter &typeConverter)
Adds MMAMatrixType conversions to SPIR-V cooperative matrix KHR type conversion to the type converter...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...