MLIR 22.0.0git
WmmaOpsToSPIRV.cpp
Go to the documentation of this file.
1//===------ WmmaOpsToSPIRV.cpp - WMMA LD/ST/Compute to SPIRV lowering -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains definitions of patterns to lower GPU Subgroup MMA ops to
10// SPIRV Cooperative Matrix ops.
11//
12//===----------------------------------------------------------------------===//
13
24#include "mlir/IR/ValueRange.h"
25#include "llvm/ADT/STLExtras.h"
26#include "llvm/ADT/StringSwitch.h"
27
28#include <cassert>
29
30namespace mlir {
31//===----------------------------------------------------------------------===//
32// Patterns and helpers.
33//===----------------------------------------------------------------------===//
34
35/// Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op
36/// when the elementwise op directly supports with cooperative matrix type.
37/// Returns false if cannot.
38///
39/// See SPV_KHR_cooperative_matrix for supported elementwise ops.
40static bool createElementwiseOp(ConversionPatternRewriter &builder,
41 gpu::SubgroupMmaElementwiseOp op, Type coopType,
42 ValueRange operands) {
43 assert((isa<spirv::CooperativeMatrixType>(coopType)));
44
45 switch (op.getOpType()) {
46 case gpu::MMAElementwiseOp::ADDF:
47 builder.replaceOpWithNewOp<spirv::FAddOp>(op, coopType, operands);
48 return true;
49 case gpu::MMAElementwiseOp::ADDI:
50 builder.replaceOpWithNewOp<spirv::IAddOp>(op, coopType, operands);
51 return true;
52 case gpu::MMAElementwiseOp::SUBF:
53 builder.replaceOpWithNewOp<spirv::FSubOp>(op, coopType, operands);
54 return true;
55 case gpu::MMAElementwiseOp::SUBI:
56 builder.replaceOpWithNewOp<spirv::ISubOp>(op, coopType, operands);
57 return true;
58 case gpu::MMAElementwiseOp::MULF:
59 builder.replaceOpWithNewOp<spirv::FMulOp>(op, coopType, operands);
60 return true;
61 case gpu::MMAElementwiseOp::DIVF:
62 builder.replaceOpWithNewOp<spirv::FDivOp>(op, coopType, operands);
63 return true;
64 case gpu::MMAElementwiseOp::DIVS:
65 builder.replaceOpWithNewOp<spirv::SDivOp>(op, coopType, operands);
66 return true;
67 case gpu::MMAElementwiseOp::DIVU:
68 builder.replaceOpWithNewOp<spirv::UDivOp>(op, coopType, operands);
69 return true;
70 case gpu::MMAElementwiseOp::NEGATEF:
71 builder.replaceOpWithNewOp<spirv::FNegateOp>(op, coopType, operands);
72 return true;
73 case gpu::MMAElementwiseOp::NEGATES:
74 builder.replaceOpWithNewOp<spirv::SNegateOp>(op, coopType, operands);
75 return true;
76 case gpu::MMAElementwiseOp::EXTF:
77 builder.replaceOpWithNewOp<spirv::FConvertOp>(op, coopType, operands);
78 return true;
79 default:
80 break;
81 }
82 return false;
83}
84
86 assert(!operands.empty());
87 if (!llvm::all_equal(
88 llvm::map_range(operands, [](Value v) { return v.getType(); })))
89 return false;
90
91 return isa<spirv::CooperativeMatrixType>(operands.front().getType());
92}
93
94namespace {
95/// Converts GPU MMA ConstantMatrixOp to constant SPIR-V KHR/NV cooperative
96/// matrix ops.
97struct WmmaConstantOpToSPIRVLowering final
98 : OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
99 using Base::Base;
100
101 LogicalResult
102 matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp op, OpAdaptor adaptor,
103 ConversionPatternRewriter &rewriter) const override {
104 Value cst = llvm::getSingleElement(adaptor.getOperands());
105 auto coopType = getTypeConverter()->convertType(op.getType());
106 if (!coopType)
107 return rewriter.notifyMatchFailure(op, "type conversion failed");
108
109 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, coopType, cst);
110 return success();
111 }
112};
113
114/// Converts GPU MMA ExtractOp to CompositeExtract SPIR-V KHR/NV cooperative
115/// matrix ops.
116struct WmmaExtractOpToSPIRVLowering final
117 : OpConversionPattern<gpu::SubgroupMmaExtractThreadLocalOp> {
118 using Base::Base;
119
120 LogicalResult
121 matchAndRewrite(gpu::SubgroupMmaExtractThreadLocalOp op, OpAdaptor adaptor,
122 ConversionPatternRewriter &rewriter) const override {
123 Value matrix = adaptor.getMatrix();
124 auto coopType =
125 getTypeConverter()->convertType<spirv::CooperativeMatrixType>(
126 matrix.getType());
127 if (!coopType)
128 return rewriter.notifyMatchFailure(op, "type conversion failed");
129
130 SmallVector<int32_t> intValues;
131 for (Value val : op.getIndices()) {
132 if (auto constOp = val.getDefiningOp<arith::ConstantIndexOp>()) {
133 intValues.push_back(static_cast<int32_t>(constOp.value()));
134 } else {
135 return rewriter.notifyMatchFailure(op, "indices must be constants");
136 }
137 }
138
139 Type elementType = coopType.getElementType();
140 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
141 op, elementType, matrix, rewriter.getI32ArrayAttr(intValues));
142 return success();
143 }
144};
145
146/// Converts GPU MMA InsertOp to CompositeInsert SPIR-V KHR/NV cooperative
147/// matrix ops.
148struct WmmaInsertOpToSPIRVLowering final
149 : OpConversionPattern<gpu::SubgroupMmaInsertThreadLocalOp> {
150 using Base::Base;
151
152 LogicalResult
153 matchAndRewrite(gpu::SubgroupMmaInsertThreadLocalOp op, OpAdaptor adaptor,
154 ConversionPatternRewriter &rewriter) const override {
155 Value value = adaptor.getValue();
156 Value matrix = adaptor.getMatrix();
157 auto coopType = getTypeConverter()->convertType(matrix.getType());
158 if (!coopType)
159 return rewriter.notifyMatchFailure(op, "type conversion failed");
160
161 SmallVector<int32_t> intValues;
162 for (Value val : op.getIndices()) {
163 if (auto constOp = val.getDefiningOp<arith::ConstantIndexOp>()) {
164 intValues.push_back(static_cast<int32_t>(constOp.value()));
165 } else {
166 return rewriter.notifyMatchFailure(op, "indices must be constants");
167 }
168 }
169
170 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
171 op, coopType, value, matrix, rewriter.getI32ArrayAttr(intValues));
172 return success();
173 }
174};
175
176/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
177/// the default case.
178struct WmmaElementwiseOpToSPIRVDefaultLowering final
179 : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
180 using Base::Base;
181
182 LogicalResult
183 matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
184 ConversionPatternRewriter &rewriter) const override {
185 // All operands should be of cooperative matrix types.
186 if (!allOperandsHaveSameCoopMatrixType(adaptor.getOperands())) {
187 return rewriter.notifyMatchFailure(op,
188 "not all operands are coop matrices");
189 }
190
191 auto coopType = getTypeConverter()->convertType(op.getType());
192 if (!coopType)
193 return rewriter.notifyMatchFailure(op, "type conversion failed");
194
195 return success(
196 createElementwiseOp(rewriter, op, coopType, adaptor.getOperands()));
197 }
198};
199
200/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
201/// matrix times scalar case.
202struct WmmaElementwiseOpToSPIRVScalarMulLowering final
203 : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
204 using Base::Base;
205
206 LogicalResult
207 matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
208 ConversionPatternRewriter &rewriter) const override {
209 if (adaptor.getOperands().size() != 2)
210 return failure();
211
212 // All operands should be of cooperative matrix types.
213 if (!allOperandsHaveSameCoopMatrixType(adaptor.getOperands())) {
214 return rewriter.notifyMatchFailure(op,
215 "not all operands are coop matrices");
216 }
217
218 if (op.getOpType() != gpu::MMAElementwiseOp::MULF)
219 return failure();
220
221 // Use the original operands to check whether one of the operands is a splat
222 // scalar value.
223 Value lhs = op.getOperands().front();
224 Value rhs = op.getOperands().back();
225 Value splat = nullptr;
226 Value matrix = nullptr;
227 if (lhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
228 splat = adaptor.getOperands().front();
229 matrix = adaptor.getOperands().back();
230 } else if (rhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
231 matrix = adaptor.getOperands().front();
232 splat = adaptor.getOperands().back();
233 }
234 if (!splat || !matrix)
235 return rewriter.notifyMatchFailure(op, "no splat operand");
236
237 // Constant MMA matrix ops are converted to `spirv.CompositeConstruct` ops.
238 Value scalar;
239 auto cc = splat.getDefiningOp<spirv::CompositeConstructOp>();
240 if (!cc) {
241 return rewriter.notifyMatchFailure(op,
242 "splat is not a composite construct");
243 }
244
245 scalar = llvm::getSingleElement(cc.getConstituents());
246
247 auto coopType = getTypeConverter()->convertType(op.getType());
248 if (!coopType)
249 return rewriter.notifyMatchFailure(op, "type conversion failed");
250 rewriter.replaceOpWithNewOp<spirv::MatrixTimesScalarOp>(
251 op, coopType, ValueRange{matrix, scalar});
252 return success();
253 }
254};
255} // namespace
256
257//===----------------------------------------------------------------------===//
258// SPV_KHR_cooperative_matrix
259//===----------------------------------------------------------------------===//
260
261namespace khr {
262namespace {
263
264/// Converts the GPU MMA loadOp to KHRCooperativeMatrixLoad op in the SPIRV
265/// dialect.
266struct WmmaLoadOpToSPIRVLowering final
267 : OpConversionPattern<gpu::SubgroupMmaLoadMatrixOp> {
268 using Base::Base;
269
270 LogicalResult
271 matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp op, OpAdaptor adaptor,
272 ConversionPatternRewriter &rewriter) const override {
273 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
274 Location loc = op->getLoc();
275
276 auto retType = cast<gpu::MMAMatrixType>(op.getRes().getType());
277 MemRefType memrefType = op.getSrcMemref().getType();
278 Value bufferPtr =
279 spirv::getElementPtr(typeConverter, memrefType, adaptor.getSrcMemref(),
280 adaptor.getIndices(), loc, rewriter);
281
282 auto coopType =
283 typeConverter.convertType<spirv::CooperativeMatrixType>(retType);
284 if (!coopType)
285 return rewriter.notifyMatchFailure(op, "type conversion failed");
286
287 int64_t stride = op.getLeadDimension().getSExtValue();
288 IntegerType i32Type = rewriter.getI32Type();
289 auto strideValue = spirv::ConstantOp::create(
290 rewriter, loc, i32Type, IntegerAttr::get(i32Type, stride));
291
292 bool isColMajor = op.getTranspose().value_or(false);
293 auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
294 : spirv::CooperativeMatrixLayoutKHR::RowMajor;
295
296 rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixLoadOp>(
297 op, coopType, bufferPtr, strideValue, layout);
298 return success();
299 }
300};
301
302/// Converts the GPU MMA StoreOp to KHRCooperativeMatrixStore op in the SPIRV
303/// dialect.
304struct WmmaStoreOpToSPIRVLowering final
305 : OpConversionPattern<gpu::SubgroupMmaStoreMatrixOp> {
306 using Base::Base;
307
308 LogicalResult
309 matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp op, OpAdaptor adaptor,
310 ConversionPatternRewriter &rewriter) const override {
311 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
312 Location loc = op->getLoc();
313
314 auto memrefType = cast<MemRefType>(op.getDstMemref().getType());
315 Value bufferPtr =
316 spirv::getElementPtr(typeConverter, memrefType, adaptor.getDstMemref(),
317 adaptor.getIndices(), loc, rewriter);
318
319 int64_t stride = op.getLeadDimension().getSExtValue();
320 IntegerType i32Type = rewriter.getI32Type();
321 auto strideValue = spirv::ConstantOp::create(
322 rewriter, loc, i32Type, IntegerAttr::get(i32Type, stride));
323
324 bool isColMajor = op.getTranspose().value_or(false);
325 auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
326 : spirv::CooperativeMatrixLayoutKHR::RowMajor;
327
328 rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixStoreOp>(
329 op, bufferPtr, adaptor.getSrc(), strideValue, layout);
330 return success();
331 }
332};
333
334/// Converts GPU MMA Compute to KHRCooperativeMatrixMulAdd op in the SPIRV
335/// dialect.
336struct WmmaMmaOpToSPIRVLowering final
337 : OpConversionPattern<gpu::SubgroupMmaComputeOp> {
338 using Base::Base;
339
340 LogicalResult
341 matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
342 OpAdaptor adaptor,
343 ConversionPatternRewriter &rewriter) const override {
344 rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixMulAddOp>(
345 subgroupMmaComputeOp, adaptor.getOpA(), adaptor.getOpB(),
346 adaptor.getOpC());
347 return success();
348 }
349};
350
351} // namespace
352} // namespace khr
353} // namespace mlir
354
356 const SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
357 using namespace mlir;
358 MLIRContext *context = patterns.getContext();
359 patterns.add<khr::WmmaLoadOpToSPIRVLowering, khr::WmmaMmaOpToSPIRVLowering,
360 khr::WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
361 WmmaExtractOpToSPIRVLowering, WmmaInsertOpToSPIRVLowering,
362 WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
363 // Give the following patterns higher benefit to prevail over the default one.
364 patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
365 /*benefit=*/2);
366}
367
369 mlir::SPIRVTypeConverter &typeConverter) {
370 typeConverter.addConversion([](gpu::MMAMatrixType type) {
371 ArrayRef<int64_t> retTypeShape = type.getShape();
372 Type elementType = type.getElementType();
373 auto use =
375 .Case("AOp", spirv::CooperativeMatrixUseKHR::MatrixA)
376 .Case("BOp", spirv::CooperativeMatrixUseKHR::MatrixB)
377 .Default(spirv::CooperativeMatrixUseKHR::MatrixAcc);
378
379 return spirv::CooperativeMatrixType::get(elementType, retTypeShape[0],
380 retTypeShape[1],
381 spirv::Scope::Subgroup, use);
382 });
383}
return success()
lhs
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
Definition GPUDialect.h:131
ArrayRef< int64_t > getShape() const
Get shape of the matrix.
Type getElementType() const
Get elementType of a single element.
StringRef getOperand() const
The general form of operation this type supports is given by the equation C += A*B.
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Include the generated interface declarations.
static bool createElementwiseOp(ConversionPatternRewriter &builder, gpu::SubgroupMmaElementwiseOp op, Type coopType, ValueRange operands)
Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op when the elementwise op dire...
void populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV, using the KHR Cooperative Ma...
const FrozenRewritePatternSet & patterns
bool allOperandsHaveSameCoopMatrixType(ValueRange operands)
void populateMMAToSPIRVCoopMatrixTypeConversion(SPIRVTypeConverter &typeConverter)
Adds MMAMatrixType conversions to SPIR-V cooperative matrix KHR type conversion to the type converter...