MLIR 23.0.0git
WmmaOpsToSPIRV.cpp
Go to the documentation of this file.
1//===------ WmmaOpsToSPIRV.cpp - WMMA LD/ST/Compute to SPIRV lowering -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains definitions of patterns to lower GPU Subgroup MMA ops to
10// SPIRV Cooperative Matrix ops.
11//
12//===----------------------------------------------------------------------===//
13
24#include "mlir/IR/ValueRange.h"
25#include "llvm/ADT/STLExtras.h"
26#include "llvm/ADT/StringSwitch.h"
27
28#include <cassert>
29
30namespace mlir {
31//===----------------------------------------------------------------------===//
32// Patterns and helpers.
33//===----------------------------------------------------------------------===//
34
35/// Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op
36/// when the elementwise op directly supports with cooperative matrix type.
37/// Returns false if cannot.
38///
39/// See SPV_KHR_cooperative_matrix for supported elementwise ops.
40static bool createElementwiseOp(ConversionPatternRewriter &builder,
41 gpu::SubgroupMmaElementwiseOp op, Type coopType,
42 ValueRange operands) {
43 assert((isa<spirv::CooperativeMatrixType>(coopType)));
44
45 switch (op.getOpType()) {
46 case gpu::MMAElementwiseOp::ADDF:
47 builder.replaceOpWithNewOp<spirv::FAddOp>(op, coopType, operands);
48 return true;
49 case gpu::MMAElementwiseOp::ADDI:
50 builder.replaceOpWithNewOp<spirv::IAddOp>(op, coopType, operands);
51 return true;
52 case gpu::MMAElementwiseOp::SUBF:
53 builder.replaceOpWithNewOp<spirv::FSubOp>(op, coopType, operands);
54 return true;
55 case gpu::MMAElementwiseOp::SUBI:
56 builder.replaceOpWithNewOp<spirv::ISubOp>(op, coopType, operands);
57 return true;
58 case gpu::MMAElementwiseOp::MULF:
59 builder.replaceOpWithNewOp<spirv::FMulOp>(op, coopType, operands);
60 return true;
61 case gpu::MMAElementwiseOp::DIVF:
62 builder.replaceOpWithNewOp<spirv::FDivOp>(op, coopType, operands);
63 return true;
64 case gpu::MMAElementwiseOp::DIVS:
65 builder.replaceOpWithNewOp<spirv::SDivOp>(op, coopType, operands);
66 return true;
67 case gpu::MMAElementwiseOp::DIVU:
68 builder.replaceOpWithNewOp<spirv::UDivOp>(op, coopType, operands);
69 return true;
70 case gpu::MMAElementwiseOp::NEGATEF:
71 builder.replaceOpWithNewOp<spirv::FNegateOp>(op, coopType, operands);
72 return true;
73 case gpu::MMAElementwiseOp::NEGATES:
74 builder.replaceOpWithNewOp<spirv::SNegateOp>(op, coopType, operands);
75 return true;
76 case gpu::MMAElementwiseOp::EXTF:
77 case gpu::MMAElementwiseOp::TRUNCF:
78 builder.replaceOpWithNewOp<spirv::FConvertOp>(op, coopType, operands);
79 return true;
80 default:
81 break;
82 }
83 return false;
84}
85
87 assert(!operands.empty());
88 if (!llvm::all_equal(
89 llvm::map_range(operands, [](Value v) { return v.getType(); })))
90 return false;
91
92 return isa<spirv::CooperativeMatrixType>(operands.front().getType());
93}
94
95namespace {
96/// Converts GPU MMA ConstantMatrixOp to constant SPIR-V KHR/NV cooperative
97/// matrix ops.
98struct WmmaConstantOpToSPIRVLowering final
99 : OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
100 using Base::Base;
101
102 LogicalResult
103 matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp op, OpAdaptor adaptor,
104 ConversionPatternRewriter &rewriter) const override {
105 Value cst = llvm::getSingleElement(adaptor.getOperands());
106 auto coopType = getTypeConverter()->convertType(op.getType());
107 if (!coopType)
108 return rewriter.notifyMatchFailure(op, "type conversion failed");
109
110 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, coopType, cst);
111 return success();
112 }
113};
114
115/// Converts GPU MMA ExtractOp to CompositeExtract SPIR-V KHR/NV cooperative
116/// matrix ops.
117struct WmmaExtractOpToSPIRVLowering final
118 : OpConversionPattern<gpu::SubgroupMmaExtractThreadLocalOp> {
119 using Base::Base;
120
121 LogicalResult
122 matchAndRewrite(gpu::SubgroupMmaExtractThreadLocalOp op, OpAdaptor adaptor,
123 ConversionPatternRewriter &rewriter) const override {
124 Value matrix = adaptor.getMatrix();
125 auto coopType =
126 getTypeConverter()->convertType<spirv::CooperativeMatrixType>(
127 matrix.getType());
128 if (!coopType)
129 return rewriter.notifyMatchFailure(op, "type conversion failed");
130
131 SmallVector<int32_t> intValues;
132 for (Value val : op.getIndices()) {
133 if (auto constOp = val.getDefiningOp<arith::ConstantIndexOp>()) {
134 intValues.push_back(static_cast<int32_t>(constOp.value()));
135 } else {
136 return rewriter.notifyMatchFailure(op, "indices must be constants");
137 }
138 }
139
140 Type elementType = coopType.getElementType();
141 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
142 op, elementType, matrix, rewriter.getI32ArrayAttr(intValues));
143 return success();
144 }
145};
146
147/// Converts GPU MMA InsertOp to CompositeInsert SPIR-V KHR/NV cooperative
148/// matrix ops.
149struct WmmaInsertOpToSPIRVLowering final
150 : OpConversionPattern<gpu::SubgroupMmaInsertThreadLocalOp> {
151 using Base::Base;
152
153 LogicalResult
154 matchAndRewrite(gpu::SubgroupMmaInsertThreadLocalOp op, OpAdaptor adaptor,
155 ConversionPatternRewriter &rewriter) const override {
156 Value value = adaptor.getValue();
157 Value matrix = adaptor.getMatrix();
158 auto coopType = getTypeConverter()->convertType(matrix.getType());
159 if (!coopType)
160 return rewriter.notifyMatchFailure(op, "type conversion failed");
161
162 SmallVector<int32_t> intValues;
163 for (Value val : op.getIndices()) {
164 if (auto constOp = val.getDefiningOp<arith::ConstantIndexOp>()) {
165 intValues.push_back(static_cast<int32_t>(constOp.value()));
166 } else {
167 return rewriter.notifyMatchFailure(op, "indices must be constants");
168 }
169 }
170
171 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
172 op, coopType, value, matrix, rewriter.getI32ArrayAttr(intValues));
173 return success();
174 }
175};
176
177/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
178/// the default case.
179struct WmmaElementwiseOpToSPIRVDefaultLowering final
180 : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
181 using Base::Base;
182
183 LogicalResult
184 matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
185 ConversionPatternRewriter &rewriter) const override {
186 // All operands should be of cooperative matrix types.
187 if (!allOperandsHaveSameCoopMatrixType(adaptor.getOperands())) {
188 return rewriter.notifyMatchFailure(op,
189 "not all operands are coop matrices");
190 }
191
192 auto coopType = getTypeConverter()->convertType(op.getType());
193 if (!coopType)
194 return rewriter.notifyMatchFailure(op, "type conversion failed");
195
196 return success(
197 createElementwiseOp(rewriter, op, coopType, adaptor.getOperands()));
198 }
199};
200
201/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
202/// matrix times scalar case.
203struct WmmaElementwiseOpToSPIRVScalarMulLowering final
204 : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
205 using Base::Base;
206
207 LogicalResult
208 matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
209 ConversionPatternRewriter &rewriter) const override {
210 if (adaptor.getOperands().size() != 2)
211 return failure();
212
213 // All operands should be of cooperative matrix types.
214 if (!allOperandsHaveSameCoopMatrixType(adaptor.getOperands())) {
215 return rewriter.notifyMatchFailure(op,
216 "not all operands are coop matrices");
217 }
218
219 if (op.getOpType() != gpu::MMAElementwiseOp::MULF)
220 return failure();
221
222 // Use the original operands to check whether one of the operands is a splat
223 // scalar value.
224 Value lhs = op.getOperands().front();
225 Value rhs = op.getOperands().back();
226 Value splat = nullptr;
227 Value matrix = nullptr;
228 if (lhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
229 splat = adaptor.getOperands().front();
230 matrix = adaptor.getOperands().back();
231 } else if (rhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
232 matrix = adaptor.getOperands().front();
233 splat = adaptor.getOperands().back();
234 }
235 if (!splat || !matrix)
236 return rewriter.notifyMatchFailure(op, "no splat operand");
237
238 // Constant MMA matrix ops are converted to `spirv.CompositeConstruct` ops.
239 Value scalar;
240 auto cc = splat.getDefiningOp<spirv::CompositeConstructOp>();
241 if (!cc) {
242 return rewriter.notifyMatchFailure(op,
243 "splat is not a composite construct");
244 }
245
246 scalar = llvm::getSingleElement(cc.getConstituents());
247
248 auto coopType = getTypeConverter()->convertType(op.getType());
249 if (!coopType)
250 return rewriter.notifyMatchFailure(op, "type conversion failed");
251 rewriter.replaceOpWithNewOp<spirv::MatrixTimesScalarOp>(
252 op, coopType, ValueRange{matrix, scalar});
253 return success();
254 }
255};
256} // namespace
257
258//===----------------------------------------------------------------------===//
259// SPV_KHR_cooperative_matrix
260//===----------------------------------------------------------------------===//
261
262namespace khr {
263namespace {
264
265/// Converts the GPU MMA loadOp to KHRCooperativeMatrixLoad op in the SPIRV
266/// dialect.
267struct WmmaLoadOpToSPIRVLowering final
268 : OpConversionPattern<gpu::SubgroupMmaLoadMatrixOp> {
269 using Base::Base;
270
271 LogicalResult
272 matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp op, OpAdaptor adaptor,
273 ConversionPatternRewriter &rewriter) const override {
274 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
275 Location loc = op->getLoc();
276
277 auto retType = cast<gpu::MMAMatrixType>(op.getRes().getType());
278 MemRefType memrefType = op.getSrcMemref().getType();
279 Value bufferPtr =
280 spirv::getElementPtr(typeConverter, memrefType, adaptor.getSrcMemref(),
281 adaptor.getIndices(), loc, rewriter);
282
283 auto coopType =
284 typeConverter.convertType<spirv::CooperativeMatrixType>(retType);
285 if (!coopType)
286 return rewriter.notifyMatchFailure(op, "type conversion failed");
287
288 int64_t stride = op.getLeadDimension().getSExtValue();
289 IntegerType i32Type = rewriter.getI32Type();
290 auto strideValue = spirv::ConstantOp::create(
291 rewriter, loc, i32Type, IntegerAttr::get(i32Type, stride));
292
293 bool isColMajor = op.getTranspose().value_or(false);
294 auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
295 : spirv::CooperativeMatrixLayoutKHR::RowMajor;
296
297 rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixLoadOp>(
298 op, coopType, bufferPtr, strideValue, layout);
299 return success();
300 }
301};
302
303/// Converts the GPU MMA StoreOp to KHRCooperativeMatrixStore op in the SPIRV
304/// dialect.
305struct WmmaStoreOpToSPIRVLowering final
306 : OpConversionPattern<gpu::SubgroupMmaStoreMatrixOp> {
307 using Base::Base;
308
309 LogicalResult
310 matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp op, OpAdaptor adaptor,
311 ConversionPatternRewriter &rewriter) const override {
312 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
313 Location loc = op->getLoc();
314
315 auto memrefType = cast<MemRefType>(op.getDstMemref().getType());
316 Value bufferPtr =
317 spirv::getElementPtr(typeConverter, memrefType, adaptor.getDstMemref(),
318 adaptor.getIndices(), loc, rewriter);
319
320 int64_t stride = op.getLeadDimension().getSExtValue();
321 IntegerType i32Type = rewriter.getI32Type();
322 auto strideValue = spirv::ConstantOp::create(
323 rewriter, loc, i32Type, IntegerAttr::get(i32Type, stride));
324
325 bool isColMajor = op.getTranspose().value_or(false);
326 auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
327 : spirv::CooperativeMatrixLayoutKHR::RowMajor;
328
329 rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixStoreOp>(
330 op, bufferPtr, adaptor.getSrc(), strideValue, layout);
331 return success();
332 }
333};
334
335/// Converts GPU MMA Compute to KHRCooperativeMatrixMulAdd op in the SPIRV
336/// dialect.
337struct WmmaMmaOpToSPIRVLowering final
338 : OpConversionPattern<gpu::SubgroupMmaComputeOp> {
339 using Base::Base;
340
341 LogicalResult
342 matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
343 OpAdaptor adaptor,
344 ConversionPatternRewriter &rewriter) const override {
345 rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixMulAddOp>(
346 subgroupMmaComputeOp, adaptor.getOpA(), adaptor.getOpB(),
347 adaptor.getOpC());
348 return success();
349 }
350};
351
352} // namespace
353} // namespace khr
354} // namespace mlir
355
357 const SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
358 using namespace mlir;
359 MLIRContext *context = patterns.getContext();
360 patterns.add<khr::WmmaLoadOpToSPIRVLowering, khr::WmmaMmaOpToSPIRVLowering,
361 khr::WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
362 WmmaExtractOpToSPIRVLowering, WmmaInsertOpToSPIRVLowering,
363 WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
364 // Give the following patterns higher benefit to prevail over the default one.
365 patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
366 /*benefit=*/2);
367}
368
370 mlir::SPIRVTypeConverter &typeConverter) {
371 typeConverter.addConversion([](gpu::MMAMatrixType type) {
372 ArrayRef<int64_t> retTypeShape = type.getShape();
373 Type elementType = type.getElementType();
374 auto use =
376 .Case("AOp", spirv::CooperativeMatrixUseKHR::MatrixA)
377 .Case("BOp", spirv::CooperativeMatrixUseKHR::MatrixB)
378 .Default(spirv::CooperativeMatrixUseKHR::MatrixAcc);
379
380 return spirv::CooperativeMatrixType::get(elementType, retTypeShape[0],
381 retTypeShape[1],
382 spirv::Scope::Subgroup, use);
383 });
384}
return success()
lhs
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
Definition GPUDialect.h:131
ArrayRef< int64_t > getShape() const
Get shape of the matrix.
Type getElementType() const
Get elementType of a single element.
StringRef getOperand() const
The general form of operation this type supports is given by the equation C += A*B.
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Include the generated interface declarations.
static bool createElementwiseOp(ConversionPatternRewriter &builder, gpu::SubgroupMmaElementwiseOp op, Type coopType, ValueRange operands)
Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op when the elementwise op dire...
void populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV, using the KHR Cooperative Ma...
bool allOperandsHaveSameCoopMatrixType(ValueRange operands)
void populateMMAToSPIRVCoopMatrixTypeConversion(SPIRVTypeConverter &typeConverter)
Adds MMAMatrixType conversions to SPIR-V cooperative matrix KHR type conversion to the type converter...