MLIR  18.0.0git
WmmaOpsToSPIRV.cpp
Go to the documentation of this file.
1 //===------ WmmaOpsToSPIRV.cpp - WMMA LD/ST/Compute to SPIRV lowering -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains definitions of patterns to lower GPU Subgroup MMA ops to
10 // SPIRV Cooperative Matrix ops.
11 //
12 //===----------------------------------------------------------------------===//
13 
25 #include "mlir/IR/BuiltinTypes.h"
26 #include "mlir/IR/TypeUtilities.h"
27 #include "mlir/IR/ValueRange.h"
28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/ADT/StringSwitch.h"
30 
31 #include <cassert>
32 
33 namespace mlir {
34 //===----------------------------------------------------------------------===//
35 // Patterns and helpers used by both the KHR and the NV lowering paths.
36 //===----------------------------------------------------------------------===//
37 
38 /// Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op
39 /// when the elementwise op directly supports with cooperative matrix type.
40 /// Returns false if cannot.
41 ///
42 /// See SPV_NV_cooperative_matrix for supported elementwise ops.
44  gpu::SubgroupMmaElementwiseOp op, Type coopType,
45  ValueRange operands) {
46  assert((isa<spirv::CooperativeMatrixType, spirv::CooperativeMatrixNVType>(
47  coopType)));
48 
49  switch (op.getOpType()) {
50  case gpu::MMAElementwiseOp::ADDF:
51  builder.replaceOpWithNewOp<spirv::FAddOp>(op, coopType, operands);
52  return true;
54  builder.replaceOpWithNewOp<spirv::IAddOp>(op, coopType, operands);
55  return true;
56  case gpu::MMAElementwiseOp::SUBF:
57  builder.replaceOpWithNewOp<spirv::FSubOp>(op, coopType, operands);
58  return true;
60  builder.replaceOpWithNewOp<spirv::ISubOp>(op, coopType, operands);
61  return true;
62  case gpu::MMAElementwiseOp::DIVF:
63  builder.replaceOpWithNewOp<spirv::FDivOp>(op, coopType, operands);
64  return true;
65  case gpu::MMAElementwiseOp::DIVS:
66  builder.replaceOpWithNewOp<spirv::SDivOp>(op, coopType, operands);
67  return true;
68  case gpu::MMAElementwiseOp::DIVU:
69  builder.replaceOpWithNewOp<spirv::UDivOp>(op, coopType, operands);
70  return true;
71  case gpu::MMAElementwiseOp::NEGATEF:
72  builder.replaceOpWithNewOp<spirv::FNegateOp>(op, coopType, operands);
73  return true;
74  case gpu::MMAElementwiseOp::NEGATES:
75  builder.replaceOpWithNewOp<spirv::SNegateOp>(op, coopType, operands);
76  return true;
77  case gpu::MMAElementwiseOp::EXTF:
78  builder.replaceOpWithNewOp<spirv::FConvertOp>(op, coopType, operands);
79  return true;
80  default:
81  break;
82  }
83  return false;
84 }
85 
87  assert(!operands.empty());
88  if (!llvm::all_equal(
89  llvm::map_range(operands, [](Value v) { return v.getType(); })))
90  return false;
91 
92  return isa<spirv::CooperativeMatrixType, spirv::CooperativeMatrixNVType>(
93  operands.front().getType());
94 }
95 
96 namespace {
97 /// Converts GPU MMA ConstantMatrixOp to constant SPIR-V KHR/NV cooperative
98 /// matrix ops.
99 struct WmmaConstantOpToSPIRVLowering final
100  : OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
102 
103  LogicalResult
104  matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp op, OpAdaptor adaptor,
105  ConversionPatternRewriter &rewriter) const override {
106  assert(adaptor.getOperands().size() == 1);
107  Value cst = adaptor.getOperands().front();
108  auto coopType = getTypeConverter()->convertType(op.getType());
109  if (!coopType)
110  return rewriter.notifyMatchFailure(op, "type conversion failed");
111 
112  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, coopType, cst);
113  return success();
114  }
115 };
116 
117 /// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
118 /// the default case.
119 struct WmmaElementwiseOpToSPIRVDefaultLowering final
120  : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
122 
123  LogicalResult
124  matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
125  ConversionPatternRewriter &rewriter) const override {
126  // All operands should be of cooperative matrix types.
127  if (!allOperandsHaveSameCoopMatrixType(adaptor.getOperands())) {
128  return rewriter.notifyMatchFailure(op,
129  "not all operands are coop matrices");
130  }
131 
132  auto coopType = getTypeConverter()->convertType(op.getType());
133  if (!coopType)
134  return rewriter.notifyMatchFailure(op, "type conversion failed");
135 
136  return success(
137  createElementwiseOp(rewriter, op, coopType, adaptor.getOperands()));
138  }
139 };
140 
141 /// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
142 /// matrix times scalar case.
143 struct WmmaElementwiseOpToSPIRVScalarMulLowering final
144  : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
146 
147  LogicalResult
148  matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
149  ConversionPatternRewriter &rewriter) const override {
150  if (adaptor.getOperands().size() != 2)
151  return failure();
152 
153  // All operands should be of cooperative matrix types.
154  if (!allOperandsHaveSameCoopMatrixType(adaptor.getOperands())) {
155  return rewriter.notifyMatchFailure(op,
156  "not all operands are coop matrices");
157  }
158 
159  if (op.getOpType() != gpu::MMAElementwiseOp::MULF)
160  return failure();
161 
162  // Use the original operands to check whether one of the operands is a splat
163  // scalar value.
164  Value lhs = op.getOperands().front();
165  Value rhs = op.getOperands().back();
166  Value splat = nullptr;
167  Value matrix = nullptr;
168  if (lhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
169  splat = adaptor.getOperands().front();
170  matrix = adaptor.getOperands().back();
171  } else if (rhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
172  matrix = adaptor.getOperands().front();
173  splat = adaptor.getOperands().back();
174  }
175  if (!splat || !matrix)
176  return rewriter.notifyMatchFailure(op, "no splat operand");
177 
178  // Constant MMA matrix ops are converted to `spirv.CompositeConstruct` ops.
179  Value scalar;
180  auto cc = splat.getDefiningOp<spirv::CompositeConstructOp>();
181  if (!cc) {
182  return rewriter.notifyMatchFailure(op,
183  "splat is not a composite construct");
184  }
185 
186  assert(cc.getConstituents().size() == 1);
187  scalar = cc.getConstituents().front();
188 
189  auto coopType = getTypeConverter()->convertType(op.getType());
190  if (!coopType)
191  return rewriter.notifyMatchFailure(op, "type conversion failed");
192  rewriter.replaceOpWithNewOp<spirv::MatrixTimesScalarOp>(
193  op, coopType, ValueRange{matrix, scalar});
194  return success();
195  }
196 };
197 } // namespace
198 
199 //===----------------------------------------------------------------------===//
200 // SPV_KHR_cooperative_matrix
201 //===----------------------------------------------------------------------===//
202 
203 namespace khr {
204 namespace {
205 
206 /// Converts the GPU MMA loadOp to KHRCooperativeMatrixLoad op in the SPIRV
207 /// dialect.
208 struct WmmaLoadOpToSPIRVLowering final
209  : OpConversionPattern<gpu::SubgroupMmaLoadMatrixOp> {
211 
213  matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp op, OpAdaptor adaptor,
214  ConversionPatternRewriter &rewriter) const override {
215  const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
216  Location loc = op->getLoc();
217 
218  auto retType = cast<gpu::MMAMatrixType>(op.getRes().getType());
219  MemRefType memrefType = op.getSrcMemref().getType();
220  Value bufferPtr =
221  spirv::getElementPtr(typeConverter, memrefType, adaptor.getSrcMemref(),
222  adaptor.getIndices(), loc, rewriter);
223 
224  auto coopType =
225  typeConverter.convertType<spirv::CooperativeMatrixType>(retType);
226  if (!coopType)
227  return rewriter.notifyMatchFailure(op, "type conversion failed");
228 
229  int64_t stride = op.getLeadDimension().getSExtValue();
230  IntegerType i32Type = rewriter.getI32Type();
231  auto strideValue = rewriter.create<spirv::ConstantOp>(
232  loc, i32Type, IntegerAttr::get(i32Type, stride));
233 
234  bool isColMajor = op.getTranspose().value_or(false);
235  auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
236  : spirv::CooperativeMatrixLayoutKHR::RowMajor;
237 
238  rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixLoadOp>(
239  op, coopType, bufferPtr, strideValue, layout);
240  return success();
241  }
242 };
243 
244 /// Converts the GPU MMA StoreOp to KHRCooperativeMatrixStore op in the SPIRV
245 /// dialect.
246 struct WmmaStoreOpToSPIRVLowering final
247  : OpConversionPattern<gpu::SubgroupMmaStoreMatrixOp> {
249 
251  matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp op, OpAdaptor adaptor,
252  ConversionPatternRewriter &rewriter) const override {
253  const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
254  Location loc = op->getLoc();
255 
256  auto memrefType = cast<MemRefType>(op.getDstMemref().getType());
257  Value bufferPtr =
258  spirv::getElementPtr(typeConverter, memrefType, adaptor.getDstMemref(),
259  adaptor.getIndices(), loc, rewriter);
260 
261  int64_t stride = op.getLeadDimension().getSExtValue();
262  IntegerType i32Type = rewriter.getI32Type();
263  auto strideValue = rewriter.create<spirv::ConstantOp>(
264  loc, i32Type, IntegerAttr::get(i32Type, stride));
265 
266  bool isColMajor = op.getTranspose().value_or(false);
267  auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
268  : spirv::CooperativeMatrixLayoutKHR::RowMajor;
269 
270  rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixStoreOp>(
271  op, bufferPtr, adaptor.getSrc(), strideValue, layout);
272  return success();
273  }
274 };
275 
276 /// Converts GPU MMA Compute to KHRCooperativeMatrixMulAdd op in the SPIRV
277 /// dialect.
278 struct WmmaMmaOpToSPIRVLowering final
279  : OpConversionPattern<gpu::SubgroupMmaComputeOp> {
281 
283  matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
284  OpAdaptor adaptor,
285  ConversionPatternRewriter &rewriter) const override {
286  rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixMulAddOp>(
287  subgroupMmaComputeOp, adaptor.getOpA(), adaptor.getOpB(),
288  adaptor.getOpC());
289  return success();
290  }
291 };
292 
293 } // namespace
294 } // namespace khr
295 
296 //===----------------------------------------------------------------------===//
297 // SPV_NV_cooperative_matrix
298 //===----------------------------------------------------------------------===//
299 
300 namespace nv {
301 namespace {
302 
303 /// Converts the GPU MMA loadOp to NVCooperativeMatrixLoad op in the SPIRV
304 /// dialect.
305 struct WmmaLoadOpToSPIRVLowering final
306  : OpConversionPattern<gpu::SubgroupMmaLoadMatrixOp> {
308 
310  matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp subgroupMmaLoadMatrixOp,
311  OpAdaptor adaptor,
312  ConversionPatternRewriter &rewriter) const override {
313  Location loc = subgroupMmaLoadMatrixOp->getLoc();
314  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
315 
316  gpu::MMAMatrixType retType =
317  cast<gpu::MMAMatrixType>(subgroupMmaLoadMatrixOp.getRes().getType());
318  auto memrefType =
319  cast<MemRefType>(subgroupMmaLoadMatrixOp.getSrcMemref().getType());
320  Value bufferPtr =
321  spirv::getElementPtr(typeConverter, memrefType, adaptor.getSrcMemref(),
322  adaptor.getIndices(), loc, rewriter);
323  auto coopType =
324  typeConverter.convertType<spirv::CooperativeMatrixNVType>(retType);
325  if (!coopType)
326  return rewriter.notifyMatchFailure(subgroupMmaLoadMatrixOp,
327  "type conversion failed");
328 
329  int64_t stride = subgroupMmaLoadMatrixOp.getLeadDimension().getSExtValue();
330  auto i32Type = rewriter.getI32Type();
331  auto strideValue = rewriter.create<spirv::ConstantOp>(
332  loc, i32Type, IntegerAttr::get(i32Type, stride));
333  bool isColMajor = static_cast<bool>(subgroupMmaLoadMatrixOp.getTranspose());
334  auto columnMajor = rewriter.create<spirv::ConstantOp>(
335  loc, rewriter.getI1Type(), rewriter.getBoolAttr(isColMajor));
336  rewriter.replaceOpWithNewOp<spirv::NVCooperativeMatrixLoadOp>(
337  subgroupMmaLoadMatrixOp, coopType, bufferPtr, strideValue, columnMajor,
338  spirv::MemoryAccessAttr());
339  return success();
340  }
341 };
342 
343 /// Converts the GPU MMA StoreOp to NVCooperativeMatrixStore op in the SPIRV
344 /// dialect.
345 struct WmmaStoreOpToSPIRVLowering final
346  : OpConversionPattern<gpu::SubgroupMmaStoreMatrixOp> {
348 
350  matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp subgroupMmaStoreMatrixOp,
351  OpAdaptor adaptor,
352  ConversionPatternRewriter &rewriter) const override {
353  Location loc = subgroupMmaStoreMatrixOp->getLoc();
354  auto memrefType =
355  cast<MemRefType>(subgroupMmaStoreMatrixOp.getDstMemref().getType());
356  Value bufferPtr = spirv::getElementPtr(
357  *getTypeConverter<const SPIRVTypeConverter>(), memrefType,
358  adaptor.getDstMemref(), adaptor.getIndices(), loc, rewriter);
359  int64_t stride = subgroupMmaStoreMatrixOp.getLeadDimension().getSExtValue();
360  auto i32Type = rewriter.getI32Type();
361  auto strideValue = rewriter.create<spirv::ConstantOp>(
362  loc, i32Type, IntegerAttr::get(i32Type, stride));
363  bool useColMajor =
364  static_cast<bool>(subgroupMmaStoreMatrixOp.getTranspose());
365  auto columnMajor = rewriter.create<spirv::ConstantOp>(
366  loc, rewriter.getI1Type(), rewriter.getBoolAttr(useColMajor));
367  rewriter.replaceOpWithNewOp<spirv::NVCooperativeMatrixStoreOp>(
368  subgroupMmaStoreMatrixOp, bufferPtr, adaptor.getSrc(), strideValue,
369  columnMajor, spirv::MemoryAccessAttr());
370  return success();
371  }
372 };
373 
374 /// Converts GPU MMA Compute to
375 /// NVCooperativeMatrixMulAdd op in the SPIRV dialect.
376 struct WmmaMmaOpToSPIRVLowering final
377  : OpConversionPattern<gpu::SubgroupMmaComputeOp> {
379 
381  matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
382  OpAdaptor adaptor,
383  ConversionPatternRewriter &rewriter) const override {
384  rewriter.replaceOpWithNewOp<spirv::NVCooperativeMatrixMulAddOp>(
385  subgroupMmaComputeOp, adaptor.getOpC().getType(), adaptor.getOpA(),
386  adaptor.getOpB(), adaptor.getOpC());
387  return success();
388  }
389 };
390 
391 } // namespace
392 } // namespace nv
393 } // namespace mlir
394 
396  SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
397  using namespace mlir;
398  MLIRContext *context = patterns.getContext();
399  patterns.add<khr::WmmaLoadOpToSPIRVLowering, khr::WmmaMmaOpToSPIRVLowering,
400  khr::WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
401  WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
402  // Give the following patterns higher benefit to prevail over the default one.
403  patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
404  /*benefit=*/2);
405 }
406 
408  SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
409  using namespace mlir;
410  MLIRContext *context = patterns.getContext();
411  patterns.add<nv::WmmaLoadOpToSPIRVLowering, nv::WmmaMmaOpToSPIRVLowering,
412  nv::WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
413  WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
414  // Give the following patterns higher benefit to prevail over the default one.
415  patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
416  /*benefit=*/2);
417 }
418 
420  mlir::SPIRVTypeConverter &typeConverter, bool useNVTypes) {
421  if (useNVTypes) {
422  typeConverter.addConversion([](gpu::MMAMatrixType type) {
423  ArrayRef<int64_t> retTypeShape = type.getShape();
424  Type elementType = type.getElementType();
426  elementType, spirv::Scope::Subgroup, retTypeShape[0],
427  retTypeShape[1]);
428  });
429  return;
430  }
431 
432  typeConverter.addConversion([](gpu::MMAMatrixType type) {
433  ArrayRef<int64_t> retTypeShape = type.getShape();
434  Type elementType = type.getElementType();
435  auto use =
437  .Case("AOp", spirv::CooperativeMatrixUseKHR::MatrixA)
438  .Case("BOp", spirv::CooperativeMatrixUseKHR::MatrixB)
439  .Default(spirv::CooperativeMatrixUseKHR::MatrixAcc);
440 
441  return spirv::CooperativeMatrixType::get(elementType, retTypeShape[0],
442  retTypeShape[1],
443  spirv::Scope::Subgroup, use);
444  });
445 }
#define SUBI(lhs, rhs)
Definition: LoopEmitter.cpp:37
#define ADDI(lhs, rhs)
Definition: LoopEmitter.cpp:35
IntegerType getI32Type()
Definition: Builders.cpp:83
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:116
IntegerType getI1Type()
Definition: Builders.cpp:73
This class implements a pattern rewriter for use with ConversionPatterns.
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:539
Type conversion from builtin types to SPIR-V types for shader interface.
void addConversion(FnT &&callback)
Register a conversion function.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:378
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
Definition: GPUDialect.h:129
ArrayRef< int64_t > getShape() const
Get shape of the matrix.
Definition: GPUDialect.cpp:135
Type getElementType() const
Get elementType of a single element.
Definition: GPUDialect.cpp:139
StringRef getOperand() const
The general form of operation this type supports is given by the equation C += A*B.
Definition: GPUDialect.cpp:141
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
void populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV, using the KHR Cooperative Ma...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
void populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV, using the NV Cooperative Mat...
static bool createElementwiseOp(ConversionPatternRewriter &builder, gpu::SubgroupMmaElementwiseOp op, Type coopType, ValueRange operands)
Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op when the elementwise op dire...
bool allOperandsHaveSameCoopMatrixType(ValueRange operands)
void populateMMAToSPIRVCoopMatrixTypeConversion(SPIRVTypeConverter &typeConverter, bool useNVTypes=false)
Adds MMAMatrixType conversions to SPIR-V cooperative matrix type conversion to the type converter.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26