MLIR 23.0.0git
WmmaOpsToSPIRV.cpp
Go to the documentation of this file.
1//===------ WmmaOpsToSPIRV.cpp - WMMA LD/ST/Compute to SPIRV lowering -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains definitions of patterns to lower GPU Subgroup MMA ops to
10// SPIRV Cooperative Matrix ops.
11//
12//===----------------------------------------------------------------------===//
13
24#include "mlir/IR/ValueRange.h"
25#include "llvm/ADT/STLExtras.h"
26#include "llvm/ADT/StringSwitch.h"
27
28#include <cassert>
29
30namespace mlir {
31//===----------------------------------------------------------------------===//
32// Patterns and helpers.
33//===----------------------------------------------------------------------===//
34
35/// Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op
36/// when the elementwise op directly supports with cooperative matrix type.
37/// Returns false if cannot.
38///
39/// See SPV_KHR_cooperative_matrix for supported elementwise ops.
40static bool createElementwiseOp(ConversionPatternRewriter &builder,
41 gpu::SubgroupMmaElementwiseOp op, Type coopType,
42 ValueRange operands) {
43 assert((isa<spirv::CooperativeMatrixType>(coopType)));
44
45 switch (op.getOpType()) {
46 case gpu::MMAElementwiseOp::ADDF:
47 builder.replaceOpWithNewOp<spirv::FAddOp>(op, coopType, operands);
48 return true;
49 case gpu::MMAElementwiseOp::ADDI:
50 builder.replaceOpWithNewOp<spirv::IAddOp>(op, coopType, operands);
51 return true;
52 case gpu::MMAElementwiseOp::SUBF:
53 builder.replaceOpWithNewOp<spirv::FSubOp>(op, coopType, operands);
54 return true;
55 case gpu::MMAElementwiseOp::SUBI:
56 builder.replaceOpWithNewOp<spirv::ISubOp>(op, coopType, operands);
57 return true;
58 case gpu::MMAElementwiseOp::MULF:
59 builder.replaceOpWithNewOp<spirv::FMulOp>(op, coopType, operands);
60 return true;
61 case gpu::MMAElementwiseOp::DIVF:
62 builder.replaceOpWithNewOp<spirv::FDivOp>(op, coopType, operands);
63 return true;
64 case gpu::MMAElementwiseOp::DIVS:
65 builder.replaceOpWithNewOp<spirv::SDivOp>(op, coopType, operands);
66 return true;
67 case gpu::MMAElementwiseOp::DIVU:
68 builder.replaceOpWithNewOp<spirv::UDivOp>(op, coopType, operands);
69 return true;
70 case gpu::MMAElementwiseOp::NEGATEF:
71 builder.replaceOpWithNewOp<spirv::FNegateOp>(op, coopType, operands);
72 return true;
73 case gpu::MMAElementwiseOp::NEGATES:
74 builder.replaceOpWithNewOp<spirv::SNegateOp>(op, coopType, operands);
75 return true;
76 case gpu::MMAElementwiseOp::EXTF:
77 case gpu::MMAElementwiseOp::TRUNCF:
78 builder.replaceOpWithNewOp<spirv::FConvertOp>(op, coopType, operands);
79 return true;
80 default:
81 break;
82 }
83 return false;
84}
85
87 assert(!operands.empty());
88 if (!llvm::all_equal(
89 llvm::map_range(operands, [](Value v) { return v.getType(); })))
90 return false;
91
92 return isa<spirv::CooperativeMatrixType>(operands.front().getType());
93}
94
96 auto elementType = dyn_cast<IntegerType>(type.getElementType());
97 return elementType && elementType.isSigned();
98}
99
100static spirv::CooperativeMatrixOperandsKHR
104 spirv::CooperativeMatrixType resultType) {
105 using Operands = spirv::CooperativeMatrixOperandsKHR;
106
107 Operands operands = Operands::None;
109 operands |= Operands::ASigned;
111 operands |= Operands::BSigned;
113 operands |= Operands::CSigned;
114 if (hasSignedIntegerElementType(resultType))
115 operands |= Operands::ResultSigned;
116 return operands;
117}
118
119namespace {
120/// Converts GPU MMA ConstantMatrixOp to constant SPIR-V KHR/NV cooperative
121/// matrix ops.
122struct WmmaConstantOpToSPIRVLowering final
123 : OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
124 using Base::Base;
125
126 LogicalResult
127 matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp op, OpAdaptor adaptor,
128 ConversionPatternRewriter &rewriter) const override {
129 Value cst = llvm::getSingleElement(adaptor.getOperands());
130 auto coopType = getTypeConverter()->convertType(op.getType());
131 if (!coopType)
132 return rewriter.notifyMatchFailure(op, "type conversion failed");
133
134 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, coopType, cst);
135 return success();
136 }
137};
138
139/// Converts GPU MMA ExtractOp to CompositeExtract SPIR-V KHR/NV cooperative
140/// matrix ops.
141struct WmmaExtractOpToSPIRVLowering final
142 : OpConversionPattern<gpu::SubgroupMmaExtractThreadLocalOp> {
143 using Base::Base;
144
145 LogicalResult
146 matchAndRewrite(gpu::SubgroupMmaExtractThreadLocalOp op, OpAdaptor adaptor,
147 ConversionPatternRewriter &rewriter) const override {
148 Value matrix = adaptor.getMatrix();
149 auto coopType =
150 getTypeConverter()->convertType<spirv::CooperativeMatrixType>(
151 matrix.getType());
152 if (!coopType)
153 return rewriter.notifyMatchFailure(op, "type conversion failed");
154
155 SmallVector<int32_t> intValues;
156 for (Value val : op.getIndices()) {
157 if (auto constOp = val.getDefiningOp<arith::ConstantIndexOp>()) {
158 intValues.push_back(static_cast<int32_t>(constOp.value()));
159 } else {
160 return rewriter.notifyMatchFailure(op, "indices must be constants");
161 }
162 }
163
164 Type elementType = coopType.getElementType();
165 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
166 op, elementType, matrix, rewriter.getI32ArrayAttr(intValues));
167 return success();
168 }
169};
170
171/// Converts GPU MMA InsertOp to CompositeInsert SPIR-V KHR/NV cooperative
172/// matrix ops.
173struct WmmaInsertOpToSPIRVLowering final
174 : OpConversionPattern<gpu::SubgroupMmaInsertThreadLocalOp> {
175 using Base::Base;
176
177 LogicalResult
178 matchAndRewrite(gpu::SubgroupMmaInsertThreadLocalOp op, OpAdaptor adaptor,
179 ConversionPatternRewriter &rewriter) const override {
180 Value value = adaptor.getValue();
181 Value matrix = adaptor.getMatrix();
182 auto coopType = getTypeConverter()->convertType(matrix.getType());
183 if (!coopType)
184 return rewriter.notifyMatchFailure(op, "type conversion failed");
185
186 SmallVector<int32_t> intValues;
187 for (Value val : op.getIndices()) {
188 if (auto constOp = val.getDefiningOp<arith::ConstantIndexOp>()) {
189 intValues.push_back(static_cast<int32_t>(constOp.value()));
190 } else {
191 return rewriter.notifyMatchFailure(op, "indices must be constants");
192 }
193 }
194
195 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
196 op, coopType, value, matrix, rewriter.getI32ArrayAttr(intValues));
197 return success();
198 }
199};
200
201/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
202/// the default case.
203struct WmmaElementwiseOpToSPIRVDefaultLowering final
204 : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
205 using Base::Base;
206
207 LogicalResult
208 matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
209 ConversionPatternRewriter &rewriter) const override {
210 // All operands should be of cooperative matrix types.
211 if (!allOperandsHaveSameCoopMatrixType(adaptor.getOperands())) {
212 return rewriter.notifyMatchFailure(op,
213 "not all operands are coop matrices");
214 }
215
216 auto coopType = getTypeConverter()->convertType(op.getType());
217 if (!coopType)
218 return rewriter.notifyMatchFailure(op, "type conversion failed");
219
220 return success(
221 createElementwiseOp(rewriter, op, coopType, adaptor.getOperands()));
222 }
223};
224
225/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
226/// matrix times scalar case.
227struct WmmaElementwiseOpToSPIRVScalarMulLowering final
228 : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
229 using Base::Base;
230
231 LogicalResult
232 matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
233 ConversionPatternRewriter &rewriter) const override {
234 if (adaptor.getOperands().size() != 2)
235 return failure();
236
237 // All operands should be of cooperative matrix types.
238 if (!allOperandsHaveSameCoopMatrixType(adaptor.getOperands())) {
239 return rewriter.notifyMatchFailure(op,
240 "not all operands are coop matrices");
241 }
242
243 if (op.getOpType() != gpu::MMAElementwiseOp::MULF)
244 return failure();
245
246 // Use the original operands to check whether one of the operands is a splat
247 // scalar value.
248 Value lhs = op.getOperands().front();
249 Value rhs = op.getOperands().back();
250 Value splat = nullptr;
251 Value matrix = nullptr;
252 if (lhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
253 splat = adaptor.getOperands().front();
254 matrix = adaptor.getOperands().back();
255 } else if (rhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
256 matrix = adaptor.getOperands().front();
257 splat = adaptor.getOperands().back();
258 }
259 if (!splat || !matrix)
260 return rewriter.notifyMatchFailure(op, "no splat operand");
261
262 // Constant MMA matrix ops are converted to `spirv.CompositeConstruct` ops.
263 Value scalar;
264 auto cc = splat.getDefiningOp<spirv::CompositeConstructOp>();
265 if (!cc) {
266 return rewriter.notifyMatchFailure(op,
267 "splat is not a composite construct");
268 }
269
270 scalar = llvm::getSingleElement(cc.getConstituents());
271
272 auto coopType = getTypeConverter()->convertType(op.getType());
273 if (!coopType)
274 return rewriter.notifyMatchFailure(op, "type conversion failed");
275 rewriter.replaceOpWithNewOp<spirv::MatrixTimesScalarOp>(
276 op, coopType, ValueRange{matrix, scalar});
277 return success();
278 }
279};
280} // namespace
281
282//===----------------------------------------------------------------------===//
283// SPV_KHR_cooperative_matrix
284//===----------------------------------------------------------------------===//
285
286namespace khr {
287namespace {
288
289/// Converts the GPU MMA loadOp to KHRCooperativeMatrixLoad op in the SPIRV
290/// dialect.
291struct WmmaLoadOpToSPIRVLowering final
292 : OpConversionPattern<gpu::SubgroupMmaLoadMatrixOp> {
293 using Base::Base;
294
295 LogicalResult
296 matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp op, OpAdaptor adaptor,
297 ConversionPatternRewriter &rewriter) const override {
298 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
299 Location loc = op->getLoc();
300
301 auto retType = cast<gpu::MMAMatrixType>(op.getRes().getType());
302 MemRefType memrefType = op.getSrcMemref().getType();
303 Value bufferPtr =
304 spirv::getElementPtr(typeConverter, memrefType, adaptor.getSrcMemref(),
305 adaptor.getIndices(), loc, rewriter);
306
307 auto coopType =
308 typeConverter.convertType<spirv::CooperativeMatrixType>(retType);
309 if (!coopType)
310 return rewriter.notifyMatchFailure(op, "type conversion failed");
311
312 int64_t stride = op.getLeadDimension().getSExtValue();
313 IntegerType i32Type = rewriter.getI32Type();
314 auto strideValue = spirv::ConstantOp::create(
315 rewriter, loc, i32Type, IntegerAttr::get(i32Type, stride));
316
317 bool isColMajor = op.getTranspose().value_or(false);
318 auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
319 : spirv::CooperativeMatrixLayoutKHR::RowMajor;
320
321 rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixLoadOp>(
322 op, coopType, bufferPtr, strideValue, layout);
323 return success();
324 }
325};
326
327/// Converts the GPU MMA StoreOp to KHRCooperativeMatrixStore op in the SPIRV
328/// dialect.
329struct WmmaStoreOpToSPIRVLowering final
330 : OpConversionPattern<gpu::SubgroupMmaStoreMatrixOp> {
331 using Base::Base;
332
333 LogicalResult
334 matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp op, OpAdaptor adaptor,
335 ConversionPatternRewriter &rewriter) const override {
336 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
337 Location loc = op->getLoc();
338
339 auto memrefType = cast<MemRefType>(op.getDstMemref().getType());
340 Value bufferPtr =
341 spirv::getElementPtr(typeConverter, memrefType, adaptor.getDstMemref(),
342 adaptor.getIndices(), loc, rewriter);
343
344 int64_t stride = op.getLeadDimension().getSExtValue();
345 IntegerType i32Type = rewriter.getI32Type();
346 auto strideValue = spirv::ConstantOp::create(
347 rewriter, loc, i32Type, IntegerAttr::get(i32Type, stride));
348
349 bool isColMajor = op.getTranspose().value_or(false);
350 auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
351 : spirv::CooperativeMatrixLayoutKHR::RowMajor;
352
353 rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixStoreOp>(
354 op, bufferPtr, adaptor.getSrc(), strideValue, layout);
355 return success();
356 }
357};
358
359/// Converts GPU MMA Compute to KHRCooperativeMatrixMulAdd op in the SPIRV
360/// dialect.
361struct WmmaMmaOpToSPIRVLowering final
362 : OpConversionPattern<gpu::SubgroupMmaComputeOp> {
363 using Base::Base;
364
365 LogicalResult
366 matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
367 OpAdaptor adaptor,
368 ConversionPatternRewriter &rewriter) const override {
369 auto aType =
370 dyn_cast<spirv::CooperativeMatrixType>(adaptor.getOpA().getType());
371 auto bType =
372 dyn_cast<spirv::CooperativeMatrixType>(adaptor.getOpB().getType());
373 auto cType =
374 dyn_cast<spirv::CooperativeMatrixType>(adaptor.getOpC().getType());
375 auto resultType =
376 getTypeConverter()->convertType<spirv::CooperativeMatrixType>(
377 subgroupMmaComputeOp.getResult().getType());
378 if (!aType || !bType || !cType || !resultType)
379 return rewriter.notifyMatchFailure(subgroupMmaComputeOp,
380 "type conversion failed");
381
382 using Operands = spirv::CooperativeMatrixOperandsKHR;
383 Operands operands =
384 getSignedCoopMatrixOperands(aType, bType, cType, resultType);
385 spirv::CooperativeMatrixOperandsKHRAttr operandsAttr;
386 if (operands != Operands::None)
387 operandsAttr = spirv::CooperativeMatrixOperandsKHRAttr::get(
388 rewriter.getContext(), operands);
389
390 rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixMulAddOp>(
391 subgroupMmaComputeOp, adaptor.getOpA(), adaptor.getOpB(),
392 adaptor.getOpC(), operandsAttr);
393 return success();
394 }
395};
396
397} // namespace
398} // namespace khr
399} // namespace mlir
400
402 const SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
403 using namespace mlir;
404 MLIRContext *context = patterns.getContext();
405 patterns.add<khr::WmmaLoadOpToSPIRVLowering, khr::WmmaMmaOpToSPIRVLowering,
406 khr::WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
407 WmmaExtractOpToSPIRVLowering, WmmaInsertOpToSPIRVLowering,
408 WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
409 // Give the following patterns higher benefit to prevail over the default one.
410 patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
411 /*benefit=*/2);
412}
413
415 mlir::SPIRVTypeConverter &typeConverter) {
416 typeConverter.addConversion([](gpu::MMAMatrixType type) {
417 ArrayRef<int64_t> retTypeShape = type.getShape();
418 Type elementType = type.getElementType();
419 auto use =
421 .Case("AOp", spirv::CooperativeMatrixUseKHR::MatrixA)
422 .Case("BOp", spirv::CooperativeMatrixUseKHR::MatrixB)
423 .Default(spirv::CooperativeMatrixUseKHR::MatrixAcc);
424
425 return spirv::CooperativeMatrixType::get(elementType, retTypeShape[0],
426 retTypeShape[1],
427 spirv::Scope::Subgroup, use);
428 });
429}
return success()
lhs
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
Definition GPUDialect.h:139
ArrayRef< int64_t > getShape() const
Get shape of the matrix.
Type getElementType() const
Get elementType of a single element.
StringRef getOperand() const
The general form of operation this type supports is given by the equation C += A*B.
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
static bool createElementwiseOp(ConversionPatternRewriter &builder, gpu::SubgroupMmaElementwiseOp op, Type coopType, ValueRange operands)
Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op when the elementwise op dire...
void populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV, using the KHR Cooperative Ma...
static spirv::CooperativeMatrixOperandsKHR getSignedCoopMatrixOperands(spirv::CooperativeMatrixType aType, spirv::CooperativeMatrixType bType, spirv::CooperativeMatrixType cType, spirv::CooperativeMatrixType resultType)
bool allOperandsHaveSameCoopMatrixType(ValueRange operands)
void populateMMAToSPIRVCoopMatrixTypeConversion(SPIRVTypeConverter &typeConverter)
Adds MMAMatrixType conversions to SPIR-V cooperative matrix KHR type conversion to the type converter...
static bool hasSignedIntegerElementType(spirv::CooperativeMatrixType type)