MLIR  21.0.0git
WmmaOpsToSPIRV.cpp
Go to the documentation of this file.
1 //===------ WmmaOpsToSPIRV.cpp - WMMA LD/ST/Compute to SPIRV lowering -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains definitions of patterns to lower GPU Subgroup MMA ops to
10 // SPIRV Cooperative Matrix ops.
11 //
12 //===----------------------------------------------------------------------===//
13 
22 #include "mlir/IR/BuiltinTypes.h"
23 #include "mlir/IR/TypeUtilities.h"
24 #include "mlir/IR/ValueRange.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/StringSwitch.h"
27 
28 #include <cassert>
29 
30 namespace mlir {
31 //===----------------------------------------------------------------------===//
32 // Patterns and helpers.
33 //===----------------------------------------------------------------------===//
34 
35 /// Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op
36 /// when the elementwise op directly supports with cooperative matrix type.
37 /// Returns false if cannot.
38 ///
39 /// See SPV_KHR_cooperative_matrix for supported elementwise ops.
41  gpu::SubgroupMmaElementwiseOp op, Type coopType,
42  ValueRange operands) {
43  assert((isa<spirv::CooperativeMatrixType>(coopType)));
44 
45  switch (op.getOpType()) {
46  case gpu::MMAElementwiseOp::ADDF:
47  builder.replaceOpWithNewOp<spirv::FAddOp>(op, coopType, operands);
48  return true;
50  builder.replaceOpWithNewOp<spirv::IAddOp>(op, coopType, operands);
51  return true;
52  case gpu::MMAElementwiseOp::SUBF:
53  builder.replaceOpWithNewOp<spirv::FSubOp>(op, coopType, operands);
54  return true;
56  builder.replaceOpWithNewOp<spirv::ISubOp>(op, coopType, operands);
57  return true;
58  case gpu::MMAElementwiseOp::DIVF:
59  builder.replaceOpWithNewOp<spirv::FDivOp>(op, coopType, operands);
60  return true;
61  case gpu::MMAElementwiseOp::DIVS:
62  builder.replaceOpWithNewOp<spirv::SDivOp>(op, coopType, operands);
63  return true;
64  case gpu::MMAElementwiseOp::DIVU:
65  builder.replaceOpWithNewOp<spirv::UDivOp>(op, coopType, operands);
66  return true;
67  case gpu::MMAElementwiseOp::NEGATEF:
68  builder.replaceOpWithNewOp<spirv::FNegateOp>(op, coopType, operands);
69  return true;
70  case gpu::MMAElementwiseOp::NEGATES:
71  builder.replaceOpWithNewOp<spirv::SNegateOp>(op, coopType, operands);
72  return true;
73  case gpu::MMAElementwiseOp::EXTF:
74  builder.replaceOpWithNewOp<spirv::FConvertOp>(op, coopType, operands);
75  return true;
76  default:
77  break;
78  }
79  return false;
80 }
81 
83  assert(!operands.empty());
84  if (!llvm::all_equal(
85  llvm::map_range(operands, [](Value v) { return v.getType(); })))
86  return false;
87 
88  return isa<spirv::CooperativeMatrixType>(operands.front().getType());
89 }
90 
91 namespace {
92 /// Converts GPU MMA ConstantMatrixOp to constant SPIR-V KHR/NV cooperative
93 /// matrix ops.
94 struct WmmaConstantOpToSPIRVLowering final
95  : OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
97 
98  LogicalResult
99  matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp op, OpAdaptor adaptor,
100  ConversionPatternRewriter &rewriter) const override {
101  Value cst = llvm::getSingleElement(adaptor.getOperands());
102  auto coopType = getTypeConverter()->convertType(op.getType());
103  if (!coopType)
104  return rewriter.notifyMatchFailure(op, "type conversion failed");
105 
106  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, coopType, cst);
107  return success();
108  }
109 };
110 
111 /// Converts GPU MMA ExtractOp to CompositeExtract SPIR-V KHR/NV cooperative
112 /// matrix ops.
113 struct WmmaExtractOpToSPIRVLowering final
114  : OpConversionPattern<gpu::SubgroupMmaExtractThreadLocalOp> {
116 
117  LogicalResult
118  matchAndRewrite(gpu::SubgroupMmaExtractThreadLocalOp op, OpAdaptor adaptor,
119  ConversionPatternRewriter &rewriter) const override {
120  Value matrix = adaptor.getMatrix();
121  auto coopType =
122  getTypeConverter()->convertType<spirv::CooperativeMatrixType>(
123  matrix.getType());
124  if (!coopType)
125  return rewriter.notifyMatchFailure(op, "type conversion failed");
126 
127  SmallVector<int32_t> intValues;
128  for (Value val : op.getIndices()) {
129  if (auto constOp = val.getDefiningOp<arith::ConstantIndexOp>()) {
130  intValues.push_back(static_cast<int32_t>(constOp.value()));
131  } else {
132  return rewriter.notifyMatchFailure(op, "indices must be constants");
133  }
134  }
135 
136  Type elementType = coopType.getElementType();
137  rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
138  op, elementType, matrix, rewriter.getI32ArrayAttr(intValues));
139  return success();
140  }
141 };
142 
143 /// Converts GPU MMA InsertOp to CompositeInsert SPIR-V KHR/NV cooperative
144 /// matrix ops.
145 struct WmmaInsertOpToSPIRVLowering final
146  : OpConversionPattern<gpu::SubgroupMmaInsertThreadLocalOp> {
148 
149  LogicalResult
150  matchAndRewrite(gpu::SubgroupMmaInsertThreadLocalOp op, OpAdaptor adaptor,
151  ConversionPatternRewriter &rewriter) const override {
152  Value value = adaptor.getValue();
153  Value matrix = adaptor.getMatrix();
154  auto coopType = getTypeConverter()->convertType(matrix.getType());
155  if (!coopType)
156  return rewriter.notifyMatchFailure(op, "type conversion failed");
157 
158  SmallVector<int32_t> intValues;
159  for (Value val : op.getIndices()) {
160  if (auto constOp = val.getDefiningOp<arith::ConstantIndexOp>()) {
161  intValues.push_back(static_cast<int32_t>(constOp.value()));
162  } else {
163  return rewriter.notifyMatchFailure(op, "indices must be constants");
164  }
165  }
166 
167  rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
168  op, coopType, value, matrix, rewriter.getI32ArrayAttr(intValues));
169  return success();
170  }
171 };
172 
173 /// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
174 /// the default case.
175 struct WmmaElementwiseOpToSPIRVDefaultLowering final
176  : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
178 
179  LogicalResult
180  matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
181  ConversionPatternRewriter &rewriter) const override {
182  // All operands should be of cooperative matrix types.
183  if (!allOperandsHaveSameCoopMatrixType(adaptor.getOperands())) {
184  return rewriter.notifyMatchFailure(op,
185  "not all operands are coop matrices");
186  }
187 
188  auto coopType = getTypeConverter()->convertType(op.getType());
189  if (!coopType)
190  return rewriter.notifyMatchFailure(op, "type conversion failed");
191 
192  return success(
193  createElementwiseOp(rewriter, op, coopType, adaptor.getOperands()));
194  }
195 };
196 
197 /// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
198 /// matrix times scalar case.
199 struct WmmaElementwiseOpToSPIRVScalarMulLowering final
200  : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
202 
203  LogicalResult
204  matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
205  ConversionPatternRewriter &rewriter) const override {
206  if (adaptor.getOperands().size() != 2)
207  return failure();
208 
209  // All operands should be of cooperative matrix types.
210  if (!allOperandsHaveSameCoopMatrixType(adaptor.getOperands())) {
211  return rewriter.notifyMatchFailure(op,
212  "not all operands are coop matrices");
213  }
214 
215  if (op.getOpType() != gpu::MMAElementwiseOp::MULF)
216  return failure();
217 
218  // Use the original operands to check whether one of the operands is a splat
219  // scalar value.
220  Value lhs = op.getOperands().front();
221  Value rhs = op.getOperands().back();
222  Value splat = nullptr;
223  Value matrix = nullptr;
224  if (lhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
225  splat = adaptor.getOperands().front();
226  matrix = adaptor.getOperands().back();
227  } else if (rhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
228  matrix = adaptor.getOperands().front();
229  splat = adaptor.getOperands().back();
230  }
231  if (!splat || !matrix)
232  return rewriter.notifyMatchFailure(op, "no splat operand");
233 
234  // Constant MMA matrix ops are converted to `spirv.CompositeConstruct` ops.
235  Value scalar;
236  auto cc = splat.getDefiningOp<spirv::CompositeConstructOp>();
237  if (!cc) {
238  return rewriter.notifyMatchFailure(op,
239  "splat is not a composite construct");
240  }
241 
242  scalar = llvm::getSingleElement(cc.getConstituents());
243 
244  auto coopType = getTypeConverter()->convertType(op.getType());
245  if (!coopType)
246  return rewriter.notifyMatchFailure(op, "type conversion failed");
247  rewriter.replaceOpWithNewOp<spirv::MatrixTimesScalarOp>(
248  op, coopType, ValueRange{matrix, scalar});
249  return success();
250  }
251 };
252 } // namespace
253 
254 //===----------------------------------------------------------------------===//
255 // SPV_KHR_cooperative_matrix
256 //===----------------------------------------------------------------------===//
257 
258 namespace khr {
259 namespace {
260 
261 /// Converts the GPU MMA loadOp to KHRCooperativeMatrixLoad op in the SPIRV
262 /// dialect.
263 struct WmmaLoadOpToSPIRVLowering final
264  : OpConversionPattern<gpu::SubgroupMmaLoadMatrixOp> {
266 
267  LogicalResult
268  matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp op, OpAdaptor adaptor,
269  ConversionPatternRewriter &rewriter) const override {
270  const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
271  Location loc = op->getLoc();
272 
273  auto retType = cast<gpu::MMAMatrixType>(op.getRes().getType());
274  MemRefType memrefType = op.getSrcMemref().getType();
275  Value bufferPtr =
276  spirv::getElementPtr(typeConverter, memrefType, adaptor.getSrcMemref(),
277  adaptor.getIndices(), loc, rewriter);
278 
279  auto coopType =
280  typeConverter.convertType<spirv::CooperativeMatrixType>(retType);
281  if (!coopType)
282  return rewriter.notifyMatchFailure(op, "type conversion failed");
283 
284  int64_t stride = op.getLeadDimension().getSExtValue();
285  IntegerType i32Type = rewriter.getI32Type();
286  auto strideValue = rewriter.create<spirv::ConstantOp>(
287  loc, i32Type, IntegerAttr::get(i32Type, stride));
288 
289  bool isColMajor = op.getTranspose().value_or(false);
290  auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
291  : spirv::CooperativeMatrixLayoutKHR::RowMajor;
292 
293  rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixLoadOp>(
294  op, coopType, bufferPtr, strideValue, layout);
295  return success();
296  }
297 };
298 
299 /// Converts the GPU MMA StoreOp to KHRCooperativeMatrixStore op in the SPIRV
300 /// dialect.
301 struct WmmaStoreOpToSPIRVLowering final
302  : OpConversionPattern<gpu::SubgroupMmaStoreMatrixOp> {
304 
305  LogicalResult
306  matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp op, OpAdaptor adaptor,
307  ConversionPatternRewriter &rewriter) const override {
308  const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
309  Location loc = op->getLoc();
310 
311  auto memrefType = cast<MemRefType>(op.getDstMemref().getType());
312  Value bufferPtr =
313  spirv::getElementPtr(typeConverter, memrefType, adaptor.getDstMemref(),
314  adaptor.getIndices(), loc, rewriter);
315 
316  int64_t stride = op.getLeadDimension().getSExtValue();
317  IntegerType i32Type = rewriter.getI32Type();
318  auto strideValue = rewriter.create<spirv::ConstantOp>(
319  loc, i32Type, IntegerAttr::get(i32Type, stride));
320 
321  bool isColMajor = op.getTranspose().value_or(false);
322  auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
323  : spirv::CooperativeMatrixLayoutKHR::RowMajor;
324 
325  rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixStoreOp>(
326  op, bufferPtr, adaptor.getSrc(), strideValue, layout);
327  return success();
328  }
329 };
330 
331 /// Converts GPU MMA Compute to KHRCooperativeMatrixMulAdd op in the SPIRV
332 /// dialect.
333 struct WmmaMmaOpToSPIRVLowering final
334  : OpConversionPattern<gpu::SubgroupMmaComputeOp> {
336 
337  LogicalResult
338  matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
339  OpAdaptor adaptor,
340  ConversionPatternRewriter &rewriter) const override {
341  rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixMulAddOp>(
342  subgroupMmaComputeOp, adaptor.getOpA(), adaptor.getOpB(),
343  adaptor.getOpC());
344  return success();
345  }
346 };
347 
348 } // namespace
349 } // namespace khr
350 } // namespace mlir
351 
353  const SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
354  using namespace mlir;
355  MLIRContext *context = patterns.getContext();
356  patterns.add<khr::WmmaLoadOpToSPIRVLowering, khr::WmmaMmaOpToSPIRVLowering,
357  khr::WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
358  WmmaExtractOpToSPIRVLowering, WmmaInsertOpToSPIRVLowering,
359  WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
360  // Give the following patterns higher benefit to prevail over the default one.
361  patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
362  /*benefit=*/2);
363 }
364 
366  mlir::SPIRVTypeConverter &typeConverter) {
367  typeConverter.addConversion([](gpu::MMAMatrixType type) {
368  ArrayRef<int64_t> retTypeShape = type.getShape();
369  Type elementType = type.getElementType();
370  auto use =
372  .Case("AOp", spirv::CooperativeMatrixUseKHR::MatrixA)
373  .Case("BOp", spirv::CooperativeMatrixUseKHR::MatrixB)
374  .Default(spirv::CooperativeMatrixUseKHR::MatrixAcc);
375 
376  return spirv::CooperativeMatrixType::get(elementType, retTypeShape[0],
377  retTypeShape[1],
378  spirv::Scope::Subgroup, use);
379  });
380 }
#define SUBI(lhs, rhs)
Definition: LoopEmitter.cpp:37
#define ADDI(lhs, rhs)
Definition: LoopEmitter.cpp:35
IntegerType getI32Type()
Definition: Builders.cpp:62
This class implements a pattern rewriter for use with ConversionPatterns.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:452
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:681
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
Type conversion from builtin types to SPIR-V types for shader interface.
void addConversion(FnT &&callback)
Register a conversion function.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
Definition: GPUDialect.h:131
ArrayRef< int64_t > getShape() const
Get shape of the matrix.
Definition: GPUDialect.cpp:140
Type getElementType() const
Get elementType of a single element.
Definition: GPUDialect.cpp:144
StringRef getOperand() const
The general form of operation this type supports is given by the equation C += A*B.
Definition: GPUDialect.cpp:146
@ Type
An inlay hint that is for a type annotation.
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Include the generated interface declarations.
static bool createElementwiseOp(ConversionPatternRewriter &builder, gpu::SubgroupMmaElementwiseOp op, Type coopType, ValueRange operands)
Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op when the elementwise op dire...
void populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV, using the KHR Cooperative Ma...
const FrozenRewritePatternSet & patterns
bool allOperandsHaveSameCoopMatrixType(ValueRange operands)
void populateMMAToSPIRVCoopMatrixTypeConversion(SPIRVTypeConverter &typeConverter)
Adds MMAMatrixType conversions to SPIR-V cooperative matrix KHR type conversion to the type converter...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...