MLIR  21.0.0git
WmmaOpsToSPIRV.cpp
Go to the documentation of this file.
1 //===------ WmmaOpsToSPIRV.cpp - WMMA LD/ST/Compute to SPIRV lowering -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains definitions of patterns to lower GPU Subgroup MMA ops to
10 // SPIRV Cooperative Matrix ops.
11 //
12 //===----------------------------------------------------------------------===//
13 
#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/ValueRange.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"

#include <cassert>
32 
33 namespace mlir {
34 //===----------------------------------------------------------------------===//
35 // Patterns and helpers.
36 //===----------------------------------------------------------------------===//
37 
38 /// Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op
39 /// when the elementwise op directly supports with cooperative matrix type.
40 /// Returns false if cannot.
41 ///
42 /// See SPV_KHR_cooperative_matrix for supported elementwise ops.
44  gpu::SubgroupMmaElementwiseOp op, Type coopType,
45  ValueRange operands) {
46  assert((isa<spirv::CooperativeMatrixType>(coopType)));
47 
48  switch (op.getOpType()) {
49  case gpu::MMAElementwiseOp::ADDF:
50  builder.replaceOpWithNewOp<spirv::FAddOp>(op, coopType, operands);
51  return true;
53  builder.replaceOpWithNewOp<spirv::IAddOp>(op, coopType, operands);
54  return true;
55  case gpu::MMAElementwiseOp::SUBF:
56  builder.replaceOpWithNewOp<spirv::FSubOp>(op, coopType, operands);
57  return true;
59  builder.replaceOpWithNewOp<spirv::ISubOp>(op, coopType, operands);
60  return true;
61  case gpu::MMAElementwiseOp::DIVF:
62  builder.replaceOpWithNewOp<spirv::FDivOp>(op, coopType, operands);
63  return true;
64  case gpu::MMAElementwiseOp::DIVS:
65  builder.replaceOpWithNewOp<spirv::SDivOp>(op, coopType, operands);
66  return true;
67  case gpu::MMAElementwiseOp::DIVU:
68  builder.replaceOpWithNewOp<spirv::UDivOp>(op, coopType, operands);
69  return true;
70  case gpu::MMAElementwiseOp::NEGATEF:
71  builder.replaceOpWithNewOp<spirv::FNegateOp>(op, coopType, operands);
72  return true;
73  case gpu::MMAElementwiseOp::NEGATES:
74  builder.replaceOpWithNewOp<spirv::SNegateOp>(op, coopType, operands);
75  return true;
76  case gpu::MMAElementwiseOp::EXTF:
77  builder.replaceOpWithNewOp<spirv::FConvertOp>(op, coopType, operands);
78  return true;
79  default:
80  break;
81  }
82  return false;
83 }
84 
86  assert(!operands.empty());
87  if (!llvm::all_equal(
88  llvm::map_range(operands, [](Value v) { return v.getType(); })))
89  return false;
90 
91  return isa<spirv::CooperativeMatrixType>(operands.front().getType());
92 }
93 
94 namespace {
95 /// Converts GPU MMA ConstantMatrixOp to constant SPIR-V KHR/NV cooperative
96 /// matrix ops.
97 struct WmmaConstantOpToSPIRVLowering final
98  : OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
100 
101  LogicalResult
102  matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp op, OpAdaptor adaptor,
103  ConversionPatternRewriter &rewriter) const override {
104  Value cst = llvm::getSingleElement(adaptor.getOperands());
105  auto coopType = getTypeConverter()->convertType(op.getType());
106  if (!coopType)
107  return rewriter.notifyMatchFailure(op, "type conversion failed");
108 
109  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, coopType, cst);
110  return success();
111  }
112 };
113 
114 /// Converts GPU MMA ExtractOp to CompositeExtract SPIR-V KHR/NV cooperative
115 /// matrix ops.
116 struct WmmaExtractOpToSPIRVLowering final
117  : OpConversionPattern<gpu::SubgroupMmaExtractThreadLocalOp> {
119 
120  LogicalResult
121  matchAndRewrite(gpu::SubgroupMmaExtractThreadLocalOp op, OpAdaptor adaptor,
122  ConversionPatternRewriter &rewriter) const override {
123  Value matrix = adaptor.getMatrix();
124  auto coopType =
125  getTypeConverter()->convertType<spirv::CooperativeMatrixType>(
126  matrix.getType());
127  if (!coopType)
128  return rewriter.notifyMatchFailure(op, "type conversion failed");
129 
130  SmallVector<int32_t> intValues;
131  for (Value val : op.getIndices()) {
132  if (auto constOp = val.getDefiningOp<arith::ConstantIndexOp>()) {
133  intValues.push_back(static_cast<int32_t>(constOp.value()));
134  } else {
135  return rewriter.notifyMatchFailure(op, "indices must be constants");
136  }
137  }
138 
139  Type elementType = coopType.getElementType();
140  rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
141  op, elementType, matrix, rewriter.getI32ArrayAttr(intValues));
142  return success();
143  }
144 };
145 
146 /// Converts GPU MMA InsertOp to CompositeInsert SPIR-V KHR/NV cooperative
147 /// matrix ops.
148 struct WmmaInsertOpToSPIRVLowering final
149  : OpConversionPattern<gpu::SubgroupMmaInsertThreadLocalOp> {
151 
152  LogicalResult
153  matchAndRewrite(gpu::SubgroupMmaInsertThreadLocalOp op, OpAdaptor adaptor,
154  ConversionPatternRewriter &rewriter) const override {
155  Value value = adaptor.getValue();
156  Value matrix = adaptor.getMatrix();
157  auto coopType = getTypeConverter()->convertType(matrix.getType());
158  if (!coopType)
159  return rewriter.notifyMatchFailure(op, "type conversion failed");
160 
161  SmallVector<int32_t> intValues;
162  for (Value val : op.getIndices()) {
163  if (auto constOp = val.getDefiningOp<arith::ConstantIndexOp>()) {
164  intValues.push_back(static_cast<int32_t>(constOp.value()));
165  } else {
166  return rewriter.notifyMatchFailure(op, "indices must be constants");
167  }
168  }
169 
170  rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
171  op, coopType, value, matrix, rewriter.getI32ArrayAttr(intValues));
172  return success();
173  }
174 };
175 
176 /// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
177 /// the default case.
178 struct WmmaElementwiseOpToSPIRVDefaultLowering final
179  : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
181 
182  LogicalResult
183  matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
184  ConversionPatternRewriter &rewriter) const override {
185  // All operands should be of cooperative matrix types.
186  if (!allOperandsHaveSameCoopMatrixType(adaptor.getOperands())) {
187  return rewriter.notifyMatchFailure(op,
188  "not all operands are coop matrices");
189  }
190 
191  auto coopType = getTypeConverter()->convertType(op.getType());
192  if (!coopType)
193  return rewriter.notifyMatchFailure(op, "type conversion failed");
194 
195  return success(
196  createElementwiseOp(rewriter, op, coopType, adaptor.getOperands()));
197  }
198 };
199 
200 /// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
201 /// matrix times scalar case.
202 struct WmmaElementwiseOpToSPIRVScalarMulLowering final
203  : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
205 
206  LogicalResult
207  matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
208  ConversionPatternRewriter &rewriter) const override {
209  if (adaptor.getOperands().size() != 2)
210  return failure();
211 
212  // All operands should be of cooperative matrix types.
213  if (!allOperandsHaveSameCoopMatrixType(adaptor.getOperands())) {
214  return rewriter.notifyMatchFailure(op,
215  "not all operands are coop matrices");
216  }
217 
218  if (op.getOpType() != gpu::MMAElementwiseOp::MULF)
219  return failure();
220 
221  // Use the original operands to check whether one of the operands is a splat
222  // scalar value.
223  Value lhs = op.getOperands().front();
224  Value rhs = op.getOperands().back();
225  Value splat = nullptr;
226  Value matrix = nullptr;
227  if (lhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
228  splat = adaptor.getOperands().front();
229  matrix = adaptor.getOperands().back();
230  } else if (rhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
231  matrix = adaptor.getOperands().front();
232  splat = adaptor.getOperands().back();
233  }
234  if (!splat || !matrix)
235  return rewriter.notifyMatchFailure(op, "no splat operand");
236 
237  // Constant MMA matrix ops are converted to `spirv.CompositeConstruct` ops.
238  Value scalar;
239  auto cc = splat.getDefiningOp<spirv::CompositeConstructOp>();
240  if (!cc) {
241  return rewriter.notifyMatchFailure(op,
242  "splat is not a composite construct");
243  }
244 
245  scalar = llvm::getSingleElement(cc.getConstituents());
246 
247  auto coopType = getTypeConverter()->convertType(op.getType());
248  if (!coopType)
249  return rewriter.notifyMatchFailure(op, "type conversion failed");
250  rewriter.replaceOpWithNewOp<spirv::MatrixTimesScalarOp>(
251  op, coopType, ValueRange{matrix, scalar});
252  return success();
253  }
254 };
255 } // namespace
256 
257 //===----------------------------------------------------------------------===//
258 // SPV_KHR_cooperative_matrix
259 //===----------------------------------------------------------------------===//
260 
261 namespace khr {
262 namespace {
263 
264 /// Converts the GPU MMA loadOp to KHRCooperativeMatrixLoad op in the SPIRV
265 /// dialect.
266 struct WmmaLoadOpToSPIRVLowering final
267  : OpConversionPattern<gpu::SubgroupMmaLoadMatrixOp> {
269 
270  LogicalResult
271  matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp op, OpAdaptor adaptor,
272  ConversionPatternRewriter &rewriter) const override {
273  const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
274  Location loc = op->getLoc();
275 
276  auto retType = cast<gpu::MMAMatrixType>(op.getRes().getType());
277  MemRefType memrefType = op.getSrcMemref().getType();
278  Value bufferPtr =
279  spirv::getElementPtr(typeConverter, memrefType, adaptor.getSrcMemref(),
280  adaptor.getIndices(), loc, rewriter);
281 
282  auto coopType =
283  typeConverter.convertType<spirv::CooperativeMatrixType>(retType);
284  if (!coopType)
285  return rewriter.notifyMatchFailure(op, "type conversion failed");
286 
287  int64_t stride = op.getLeadDimension().getSExtValue();
288  IntegerType i32Type = rewriter.getI32Type();
289  auto strideValue = rewriter.create<spirv::ConstantOp>(
290  loc, i32Type, IntegerAttr::get(i32Type, stride));
291 
292  bool isColMajor = op.getTranspose().value_or(false);
293  auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
294  : spirv::CooperativeMatrixLayoutKHR::RowMajor;
295 
296  rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixLoadOp>(
297  op, coopType, bufferPtr, strideValue, layout);
298  return success();
299  }
300 };
301 
302 /// Converts the GPU MMA StoreOp to KHRCooperativeMatrixStore op in the SPIRV
303 /// dialect.
304 struct WmmaStoreOpToSPIRVLowering final
305  : OpConversionPattern<gpu::SubgroupMmaStoreMatrixOp> {
307 
308  LogicalResult
309  matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp op, OpAdaptor adaptor,
310  ConversionPatternRewriter &rewriter) const override {
311  const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
312  Location loc = op->getLoc();
313 
314  auto memrefType = cast<MemRefType>(op.getDstMemref().getType());
315  Value bufferPtr =
316  spirv::getElementPtr(typeConverter, memrefType, adaptor.getDstMemref(),
317  adaptor.getIndices(), loc, rewriter);
318 
319  int64_t stride = op.getLeadDimension().getSExtValue();
320  IntegerType i32Type = rewriter.getI32Type();
321  auto strideValue = rewriter.create<spirv::ConstantOp>(
322  loc, i32Type, IntegerAttr::get(i32Type, stride));
323 
324  bool isColMajor = op.getTranspose().value_or(false);
325  auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
326  : spirv::CooperativeMatrixLayoutKHR::RowMajor;
327 
328  rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixStoreOp>(
329  op, bufferPtr, adaptor.getSrc(), strideValue, layout);
330  return success();
331  }
332 };
333 
334 /// Converts GPU MMA Compute to KHRCooperativeMatrixMulAdd op in the SPIRV
335 /// dialect.
336 struct WmmaMmaOpToSPIRVLowering final
337  : OpConversionPattern<gpu::SubgroupMmaComputeOp> {
339 
340  LogicalResult
341  matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
342  OpAdaptor adaptor,
343  ConversionPatternRewriter &rewriter) const override {
344  rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixMulAddOp>(
345  subgroupMmaComputeOp, adaptor.getOpA(), adaptor.getOpB(),
346  adaptor.getOpC());
347  return success();
348  }
349 };
350 
351 } // namespace
352 } // namespace khr
353 } // namespace mlir
354 
356  const SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
357  using namespace mlir;
358  MLIRContext *context = patterns.getContext();
359  patterns.add<khr::WmmaLoadOpToSPIRVLowering, khr::WmmaMmaOpToSPIRVLowering,
360  khr::WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
361  WmmaExtractOpToSPIRVLowering, WmmaInsertOpToSPIRVLowering,
362  WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
363  // Give the following patterns higher benefit to prevail over the default one.
364  patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
365  /*benefit=*/2);
366 }
367 
369  mlir::SPIRVTypeConverter &typeConverter) {
370  typeConverter.addConversion([](gpu::MMAMatrixType type) {
371  ArrayRef<int64_t> retTypeShape = type.getShape();
372  Type elementType = type.getElementType();
373  auto use =
375  .Case("AOp", spirv::CooperativeMatrixUseKHR::MatrixA)
376  .Case("BOp", spirv::CooperativeMatrixUseKHR::MatrixB)
377  .Default(spirv::CooperativeMatrixUseKHR::MatrixAcc);
378 
379  return spirv::CooperativeMatrixType::get(elementType, retTypeShape[0],
380  retTypeShape[1],
381  spirv::Scope::Subgroup, use);
382  });
383 }
#define SUBI(lhs, rhs)
Definition: LoopEmitter.cpp:37
#define ADDI(lhs, rhs)
Definition: LoopEmitter.cpp:35
IntegerType getI32Type()
Definition: Builders.cpp:63
This class implements a pattern rewriter for use with ConversionPatterns.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:682
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
Type conversion from builtin types to SPIR-V types for shader interface.
void addConversion(FnT &&callback)
Register a conversion function.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
Definition: GPUDialect.h:131
ArrayRef< int64_t > getShape() const
Get shape of the matrix.
Definition: GPUDialect.cpp:140
Type getElementType() const
Get elementType of a single element.
Definition: GPUDialect.cpp:144
StringRef getOperand() const
The general form of operation this type supports is given by the equation C += A*B.
Definition: GPUDialect.cpp:146
@ Type
An inlay hint that for a type annotation.
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Include the generated interface declarations.
static bool createElementwiseOp(ConversionPatternRewriter &builder, gpu::SubgroupMmaElementwiseOp op, Type coopType, ValueRange operands)
Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op when the elementwise op dire...
void populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV, using the KHR Cooperative Ma...
const FrozenRewritePatternSet & patterns
bool allOperandsHaveSameCoopMatrixType(ValueRange operands)
void populateMMAToSPIRVCoopMatrixTypeConversion(SPIRVTypeConverter &typeConverter)
Adds MMAMatrixType conversions to SPIR-V cooperative matrix KHR type conversion to the type converter...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...