MLIR 22.0.0git
GPUToSPIRV.cpp
Go to the documentation of this file.
1//===- GPUToSPIRV.cpp - GPU to SPIR-V Patterns ----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns to convert GPU dialect to SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
21#include "mlir/IR/Matchers.h"
23#include <optional>
24
25using namespace mlir;
26
/// Prefix prepended to the gpu.module symbol name when creating the
/// corresponding spirv.module, to avoid a symbol collision with the original
/// module (see GPUModuleConversion below).
static constexpr const char kSPIRVModule[] = "__spv__";
28
namespace {
/// Pattern lowering GPU block/thread size/id to loading SPIR-V invocation
/// builtin variables.
template <typename SourceOp, spirv::BuiltIn builtin>
class LaunchConfigConversion : public OpConversionPattern<SourceOp> {
public:
  using OpConversionPattern<SourceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Pattern lowering subgroup size/id to loading SPIR-V invocation
/// builtin variables. Unlike LaunchConfigConversion, these builtins are
/// scalars (no dimension operand).
template <typename SourceOp, spirv::BuiltIn builtin>
class SingleDimLaunchConfigConversion : public OpConversionPattern<SourceOp> {
public:
  using OpConversionPattern<SourceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// This is separate because in Vulkan workgroup size is exposed to shaders via
/// a constant with WorkgroupSize decoration. So here we cannot generate a
/// builtin variable; instead the information in the `spirv.entry_point_abi`
/// attribute on the surrounding FuncOp is used to replace the gpu::BlockDimOp.
class WorkGroupSizeConversion : public OpConversionPattern<gpu::BlockDimOp> {
public:
  // Benefit 10 so this pattern is preferred over the generic
  // LaunchConfigConversion for gpu.block_dim when both are registered.
  WorkGroupSizeConversion(const TypeConverter &typeConverter,
                          MLIRContext *context)
      : OpConversionPattern(typeConverter, context, /*benefit*/ 10) {}

  LogicalResult
  matchAndRewrite(gpu::BlockDimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Pattern to convert a kernel function in GPU dialect within a spirv.module.
class GPUFuncOpConversion final : public OpConversionPattern<gpu::GPUFuncOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(gpu::GPUFuncOp funcOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;

private:
  // NOTE(review): not referenced by any code visible in this file — looks
  // vestigial; confirm before removing.
  SmallVector<int32_t, 3> workGroupSizeAsInt32;
};

/// Pattern to convert a gpu.module to a spirv.module.
class GPUModuleConversion final : public OpConversionPattern<gpu::GPUModuleOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(gpu::GPUModuleOp moduleOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Pattern to convert a gpu.return into a SPIR-V return.
// TODO: This can go to DRR when GPU return has operands.
class GPUReturnOpConversion final : public OpConversionPattern<gpu::ReturnOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(gpu::ReturnOp returnOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Pattern to convert a gpu.barrier op into a spirv.ControlBarrier op.
class GPUBarrierConversion final : public OpConversionPattern<gpu::BarrierOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(gpu::BarrierOp barrierOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Pattern to convert a gpu.shuffle op into a spirv.GroupNonUniformShuffle op.
class GPUShuffleConversion final : public OpConversionPattern<gpu::ShuffleOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(gpu::ShuffleOp shuffleOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Pattern to convert a gpu.rotate op into a spirv.GroupNonUniformRotateKHROp.
class GPURotateConversion final : public OpConversionPattern<gpu::RotateOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(gpu::RotateOp rotateOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Pattern to convert a gpu.printf op into a spirv.CL.printf op.
class GPUPrintfConversion final : public OpConversionPattern<gpu::PrintfOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

} // namespace
143
144//===----------------------------------------------------------------------===//
145// Builtins.
146//===----------------------------------------------------------------------===//
147
148template <typename SourceOp, spirv::BuiltIn builtin>
149LogicalResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
150 SourceOp op, typename SourceOp::Adaptor adaptor,
151 ConversionPatternRewriter &rewriter) const {
152 auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
153 Type indexType = typeConverter->getIndexType();
154
155 // For Vulkan, these SPIR-V builtin variables are required to be a vector of
156 // type <3xi32> by the spec:
157 // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/NumWorkgroups.html
158 // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/WorkgroupId.html
159 // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/WorkgroupSize.html
160 // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/LocalInvocationId.html
161 // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/LocalInvocationId.html
162 // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/GlobalInvocationId.html
163 //
164 // For OpenCL, it depends on the Physical32/Physical64 addressing model:
165 // https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_Env.html#_built_in_variables
166 bool forShader =
167 typeConverter->getTargetEnv().allows(spirv::Capability::Shader);
168 Type builtinType = forShader ? rewriter.getIntegerType(32) : indexType;
169
170 Value vector =
171 spirv::getBuiltinVariableValue(op, builtin, builtinType, rewriter);
172 Value dim = spirv::CompositeExtractOp::create(
173 rewriter, op.getLoc(), builtinType, vector,
174 rewriter.getI32ArrayAttr({static_cast<int32_t>(op.getDimension())}));
175 if (forShader && builtinType != indexType)
176 dim = spirv::UConvertOp::create(rewriter, op.getLoc(), indexType, dim);
177 rewriter.replaceOp(op, dim);
178 return success();
179}
180
181template <typename SourceOp, spirv::BuiltIn builtin>
182LogicalResult
183SingleDimLaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
184 SourceOp op, typename SourceOp::Adaptor adaptor,
185 ConversionPatternRewriter &rewriter) const {
186 auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
187 Type indexType = typeConverter->getIndexType();
188 Type i32Type = rewriter.getIntegerType(32);
189
190 // For Vulkan, these SPIR-V builtin variables are required to be a vector of
191 // type i32 by the spec:
192 // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/NumSubgroups.html
193 // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/SubgroupId.html
194 // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/SubgroupSize.html
195 //
196 // For OpenCL, they are also required to be i32:
197 // https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_Env.html#_built_in_variables
198 Value builtinValue =
199 spirv::getBuiltinVariableValue(op, builtin, i32Type, rewriter);
200 if (i32Type != indexType)
201 builtinValue = spirv::UConvertOp::create(rewriter, op.getLoc(), indexType,
202 builtinValue);
203 rewriter.replaceOp(op, builtinValue);
204 return success();
205}
206
207LogicalResult WorkGroupSizeConversion::matchAndRewrite(
208 gpu::BlockDimOp op, OpAdaptor adaptor,
209 ConversionPatternRewriter &rewriter) const {
211 if (!workGroupSizeAttr)
212 return failure();
213
214 int val =
215 workGroupSizeAttr.asArrayRef()[static_cast<int32_t>(op.getDimension())];
216 auto convertedType =
217 getTypeConverter()->convertType(op.getResult().getType());
218 if (!convertedType)
219 return failure();
220 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
221 op, convertedType, IntegerAttr::get(convertedType, val));
222 return success();
223}
224
225//===----------------------------------------------------------------------===//
226// GPUFuncOp
227//===----------------------------------------------------------------------===//
228
229// Legalizes a GPU function as an entry SPIR-V function.
230static spirv::FuncOp
231lowerAsEntryFunction(gpu::GPUFuncOp funcOp, const TypeConverter &typeConverter,
232 ConversionPatternRewriter &rewriter,
233 spirv::EntryPointABIAttr entryPointInfo,
235 auto fnType = funcOp.getFunctionType();
236 if (fnType.getNumResults()) {
237 funcOp.emitError("SPIR-V lowering only supports entry functions"
238 "with no return values right now");
239 return nullptr;
240 }
241 if (!argABIInfo.empty() && fnType.getNumInputs() != argABIInfo.size()) {
242 funcOp.emitError(
243 "lowering as entry functions requires ABI info for all arguments "
244 "or none of them");
245 return nullptr;
246 }
247 // Update the signature to valid SPIR-V types and add the ABI
248 // attributes. These will be "materialized" by using the
249 // LowerABIAttributesPass.
250 TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
251 {
252 for (const auto &argType :
253 enumerate(funcOp.getFunctionType().getInputs())) {
254 auto convertedType = typeConverter.convertType(argType.value());
255 if (!convertedType)
256 return nullptr;
257 signatureConverter.addInputs(argType.index(), convertedType);
258 }
259 }
260 auto newFuncOp = spirv::FuncOp::create(
261 rewriter, funcOp.getLoc(), funcOp.getName(),
262 rewriter.getFunctionType(signatureConverter.getConvertedTypes(), {}));
263 for (const auto &namedAttr : funcOp->getAttrs()) {
264 if (namedAttr.getName() == funcOp.getFunctionTypeAttrName() ||
265 namedAttr.getName() == SymbolTable::getSymbolAttrName())
266 continue;
267 newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
268 }
269
270 rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
271 newFuncOp.end());
272 if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
273 &signatureConverter)))
274 return nullptr;
275 rewriter.eraseOp(funcOp);
276
277 // Set the attributes for argument and the function.
278 StringRef argABIAttrName = spirv::getInterfaceVarABIAttrName();
279 for (auto argIndex : llvm::seq<unsigned>(0, argABIInfo.size())) {
280 newFuncOp.setArgAttr(argIndex, argABIAttrName, argABIInfo[argIndex]);
281 }
282 newFuncOp->setAttr(spirv::getEntryPointABIAttrName(), entryPointInfo);
283
284 return newFuncOp;
285}
286
287/// Populates `argABI` with spirv.interface_var_abi attributes for lowering
288/// gpu.func to spirv.func if no arguments have the attributes set
289/// already. Returns failure if any argument has the ABI attribute set already.
290static LogicalResult
291getDefaultABIAttrs(const spirv::TargetEnv &targetEnv, gpu::GPUFuncOp funcOp,
293 if (!spirv::needsInterfaceVarABIAttrs(targetEnv))
294 return success();
295
296 for (auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
297 if (funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
299 return failure();
300 // Vulkan's interface variable requirements needs scalars to be wrapped in a
301 // struct. The struct held in storage buffer.
302 std::optional<spirv::StorageClass> sc;
303 if (funcOp.getArgument(argIndex).getType().isIntOrIndexOrFloat())
304 sc = spirv::StorageClass::StorageBuffer;
305 argABI.push_back(
306 spirv::getInterfaceVarABIAttr(0, argIndex, sc, funcOp.getContext()));
307 }
308 return success();
309}
310
311LogicalResult GPUFuncOpConversion::matchAndRewrite(
312 gpu::GPUFuncOp funcOp, OpAdaptor adaptor,
313 ConversionPatternRewriter &rewriter) const {
314 if (!gpu::GPUDialect::isKernel(funcOp))
315 return failure();
316
317 auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
318 SmallVector<spirv::InterfaceVarABIAttr, 4> argABI;
319 if (failed(
320 getDefaultABIAttrs(typeConverter->getTargetEnv(), funcOp, argABI))) {
321 argABI.clear();
322 for (auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
323 // If the ABI is already specified, use it.
324 auto abiAttr = funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
326 if (!abiAttr) {
327 funcOp.emitRemark(
328 "match failure: missing 'spirv.interface_var_abi' attribute at "
329 "argument ")
330 << argIndex;
331 return failure();
332 }
333 argABI.push_back(abiAttr);
334 }
335 }
336
337 auto entryPointAttr = spirv::lookupEntryPointABI(funcOp);
338 if (!entryPointAttr) {
339 funcOp.emitRemark(
340 "match failure: missing 'spirv.entry_point_abi' attribute");
341 return failure();
342 }
343 spirv::FuncOp newFuncOp = lowerAsEntryFunction(
344 funcOp, *getTypeConverter(), rewriter, entryPointAttr, argABI);
345 if (!newFuncOp)
346 return failure();
347 newFuncOp->removeAttr(
348 rewriter.getStringAttr(gpu::GPUDialect::getKernelFuncAttrName()));
349 return success();
350}
351
352//===----------------------------------------------------------------------===//
353// ModuleOp with gpu.module.
354//===----------------------------------------------------------------------===//
355
356LogicalResult GPUModuleConversion::matchAndRewrite(
357 gpu::GPUModuleOp moduleOp, OpAdaptor adaptor,
358 ConversionPatternRewriter &rewriter) const {
359 auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
360 const spirv::TargetEnv &targetEnv = typeConverter->getTargetEnv();
361 spirv::AddressingModel addressingModel = spirv::getAddressingModel(
362 targetEnv, typeConverter->getOptions().use64bitIndex);
363 FailureOr<spirv::MemoryModel> memoryModel = spirv::getMemoryModel(targetEnv);
364 if (failed(memoryModel))
365 return moduleOp.emitRemark(
366 "cannot deduce memory model from 'spirv.target_env'");
367
368 // Add a keyword to the module name to avoid symbolic conflict.
369 std::string spvModuleName = (kSPIRVModule + moduleOp.getName()).str();
370 auto spvModule = spirv::ModuleOp::create(
371 rewriter, moduleOp.getLoc(), addressingModel, *memoryModel, std::nullopt,
372 StringRef(spvModuleName));
373
374 // Move the region from the module op into the SPIR-V module.
375 Region &spvModuleRegion = spvModule.getRegion();
376 rewriter.inlineRegionBefore(moduleOp.getBodyRegion(), spvModuleRegion,
377 spvModuleRegion.begin());
378 // The spirv.module build method adds a block. Remove that.
379 rewriter.eraseBlock(&spvModuleRegion.back());
380
381 // Some of the patterns call `lookupTargetEnv` during conversion and they
382 // will fail if called after GPUModuleConversion and we don't preserve
383 // `TargetEnv` attribute.
384 // Copy TargetEnvAttr only if it is attached directly to the GPUModuleOp.
385 if (auto attr = moduleOp->getAttrOfType<spirv::TargetEnvAttr>(
387 spvModule->setAttr(spirv::getTargetEnvAttrName(), attr);
388 if (ArrayAttr targets = moduleOp.getTargetsAttr()) {
389 for (Attribute targetAttr : targets)
390 if (auto spirvTargetEnvAttr =
391 dyn_cast<spirv::TargetEnvAttr>(targetAttr)) {
392 spvModule->setAttr(spirv::getTargetEnvAttrName(), spirvTargetEnvAttr);
393 break;
394 }
395 }
396
397 rewriter.eraseOp(moduleOp);
398 return success();
399}
400
401//===----------------------------------------------------------------------===//
402// GPU return inside kernel functions to SPIR-V return.
403//===----------------------------------------------------------------------===//
404
405LogicalResult GPUReturnOpConversion::matchAndRewrite(
406 gpu::ReturnOp returnOp, OpAdaptor adaptor,
407 ConversionPatternRewriter &rewriter) const {
408 if (!adaptor.getOperands().empty())
409 return failure();
410
411 rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp);
412 return success();
413}
414
415//===----------------------------------------------------------------------===//
416// Barrier.
417//===----------------------------------------------------------------------===//
418
419LogicalResult GPUBarrierConversion::matchAndRewrite(
420 gpu::BarrierOp barrierOp, OpAdaptor adaptor,
421 ConversionPatternRewriter &rewriter) const {
422 MLIRContext *context = getContext();
423 // Both execution and memory scope should be workgroup.
424 auto scope = spirv::ScopeAttr::get(context, spirv::Scope::Workgroup);
425 // Require acquire and release memory semantics for workgroup memory.
426 auto memorySemantics = spirv::MemorySemanticsAttr::get(
427 context, spirv::MemorySemantics::WorkgroupMemory |
428 spirv::MemorySemantics::AcquireRelease);
429 rewriter.replaceOpWithNewOp<spirv::ControlBarrierOp>(barrierOp, scope, scope,
430 memorySemantics);
431 return success();
432}
433
434//===----------------------------------------------------------------------===//
435// Shuffle
436//===----------------------------------------------------------------------===//
437
/// Lowers gpu.shuffle to the matching spirv.GroupNonUniformShuffle* op and
/// materializes the `valid` result (whether the source lane was in range).
LogicalResult GPUShuffleConversion::matchAndRewrite(
    gpu::ShuffleOp shuffleOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // Require the shuffle width to be the same as the target's subgroup size,
  // given that for SPIR-V non-uniform subgroup ops, we cannot select
  // participating invocations.
  auto targetEnv = getTypeConverter<SPIRVTypeConverter>()->getTargetEnv();
  unsigned subgroupSize =
      targetEnv.getAttr().getResourceLimits().getSubgroupSize();
  // The width must be a compile-time constant equal to the subgroup size.
  IntegerAttr widthAttr;
  if (!matchPattern(shuffleOp.getWidth(), m_Constant(&widthAttr)) ||
      widthAttr.getValue().getZExtValue() != subgroupSize)
    return rewriter.notifyMatchFailure(
        shuffleOp, "shuffle width and target subgroup size mismatch");

  assert(!adaptor.getOffset().getType().isSignedInteger() &&
         "shuffle offset must be a signless/unsigned integer");

  Location loc = shuffleOp.getLoc();
  auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
  Value result;
  Value validVal;

  switch (shuffleOp.getMode()) {
  case gpu::ShuffleMode::XOR: {
    // XOR within a full subgroup always maps to a participating lane.
    result = spirv::GroupNonUniformShuffleXorOp::create(
        rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset());
    validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(),
                                         shuffleOp.getLoc(), rewriter);
    break;
  }
  case gpu::ShuffleMode::IDX: {
    // Direct index shuffle; valid is constant true (width == subgroup size).
    result = spirv::GroupNonUniformShuffleOp::create(
        rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset());
    validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(),
                                         shuffleOp.getLoc(), rewriter);
    break;
  }
  case gpu::ShuffleMode::DOWN: {
    result = spirv::GroupNonUniformShuffleDownOp::create(
        rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset());

    // Valid iff laneId + offset stays below the width (unsigned compare).
    Value laneId = gpu::LaneIdOp::create(rewriter, loc, widthAttr);
    Value resultLaneId =
        arith::AddIOp::create(rewriter, loc, laneId, adaptor.getOffset());
    validVal = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ult,
                                     resultLaneId, adaptor.getWidth());
    break;
  }
  case gpu::ShuffleMode::UP: {
    result = spirv::GroupNonUniformShuffleUpOp::create(
        rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset());

    // Valid iff laneId - offset does not underflow below zero (signed
    // compare against 0 catches the wrap).
    Value laneId = gpu::LaneIdOp::create(rewriter, loc, widthAttr);
    Value resultLaneId =
        arith::SubIOp::create(rewriter, loc, laneId, adaptor.getOffset());
    auto i32Type = rewriter.getIntegerType(32);
    validVal = arith::CmpIOp::create(
        rewriter, loc, arith::CmpIPredicate::sge, resultLaneId,
        arith::ConstantOp::create(rewriter, loc, i32Type,
                                  rewriter.getIntegerAttr(i32Type, 0)));
    break;
  }
  }

  rewriter.replaceOp(shuffleOp, {result, validVal});
  return success();
}
506
507//===----------------------------------------------------------------------===//
508// Rotate
509//===----------------------------------------------------------------------===//
510
511LogicalResult GPURotateConversion::matchAndRewrite(
512 gpu::RotateOp rotateOp, OpAdaptor adaptor,
513 ConversionPatternRewriter &rewriter) const {
514 const spirv::TargetEnv &targetEnv =
515 getTypeConverter<SPIRVTypeConverter>()->getTargetEnv();
516 unsigned subgroupSize =
517 targetEnv.getAttr().getResourceLimits().getSubgroupSize();
518 unsigned width = rotateOp.getWidth();
519 if (width > subgroupSize)
520 return rewriter.notifyMatchFailure(
521 rotateOp, "rotate width is larger than target subgroup size");
522
523 Location loc = rotateOp.getLoc();
524 auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
525 Value offsetVal =
526 arith::ConstantOp::create(rewriter, loc, adaptor.getOffsetAttr());
527 Value widthVal =
528 arith::ConstantOp::create(rewriter, loc, adaptor.getWidthAttr());
529 Value rotateResult = spirv::GroupNonUniformRotateKHROp::create(
530 rewriter, loc, scope, adaptor.getValue(), offsetVal, widthVal);
531 Value validVal;
532 if (width == subgroupSize) {
533 validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(), loc, rewriter);
534 } else {
535 IntegerAttr widthAttr = adaptor.getWidthAttr();
536 Value laneId = gpu::LaneIdOp::create(rewriter, loc, widthAttr);
537 validVal = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ult,
538 laneId, widthVal);
539 }
540
541 rewriter.replaceOp(rotateOp, {rotateResult, validVal});
542 return success();
543}
544
545//===----------------------------------------------------------------------===//
546// Group ops
547//===----------------------------------------------------------------------===//
548
549template <typename UniformOp, typename NonUniformOp>
551 Value arg, bool isGroup, bool isUniform,
552 std::optional<uint32_t> clusterSize) {
553 Type type = arg.getType();
554 auto scope = mlir::spirv::ScopeAttr::get(builder.getContext(),
555 isGroup ? spirv::Scope::Workgroup
556 : spirv::Scope::Subgroup);
557 auto groupOp = spirv::GroupOperationAttr::get(
558 builder.getContext(), clusterSize.has_value()
559 ? spirv::GroupOperation::ClusteredReduce
560 : spirv::GroupOperation::Reduce);
561 if (isUniform) {
562 return UniformOp::create(builder, loc, type, scope, groupOp, arg)
563 .getResult();
564 }
565
566 Value clusterSizeValue;
567 if (clusterSize.has_value())
568 clusterSizeValue = spirv::ConstantOp::create(
569 builder, loc, builder.getI32Type(),
570 builder.getIntegerAttr(builder.getI32Type(), *clusterSize));
571
572 return NonUniformOp::create(builder, loc, type, scope, groupOp, arg,
573 clusterSizeValue)
574 .getResult();
575}
576
577static std::optional<Value>
579 gpu::AllReduceOperation opType, bool isGroup,
580 bool isUniform, std::optional<uint32_t> clusterSize) {
581 enum class ElemType { Float, Boolean, Integer };
582 using FuncT = Value (*)(OpBuilder &, Location, Value, bool, bool,
583 std::optional<uint32_t>);
584 struct OpHandler {
585 gpu::AllReduceOperation kind;
586 ElemType elemType;
587 FuncT func;
588 };
589
590 Type type = arg.getType();
591 ElemType elementType;
592 if (isa<FloatType>(type)) {
593 elementType = ElemType::Float;
594 } else if (auto intTy = dyn_cast<IntegerType>(type)) {
595 elementType = (intTy.getIntOrFloatBitWidth() == 1) ? ElemType::Boolean
596 : ElemType::Integer;
597 } else {
598 return std::nullopt;
599 }
600
601 // TODO(https://github.com/llvm/llvm-project/issues/73459): The SPIR-V spec
602 // does not specify how -0.0 / +0.0 and NaN values are handled in *FMin/*FMax
603 // reduction ops. We should account possible precision requirements in this
604 // conversion.
605
606 using ReduceType = gpu::AllReduceOperation;
607 const OpHandler handlers[] = {
608 {ReduceType::ADD, ElemType::Integer,
609 &createGroupReduceOpImpl<spirv::GroupIAddOp,
610 spirv::GroupNonUniformIAddOp>},
611 {ReduceType::ADD, ElemType::Float,
612 &createGroupReduceOpImpl<spirv::GroupFAddOp,
613 spirv::GroupNonUniformFAddOp>},
614 {ReduceType::MUL, ElemType::Integer,
615 &createGroupReduceOpImpl<spirv::GroupIMulKHROp,
616 spirv::GroupNonUniformIMulOp>},
617 {ReduceType::MUL, ElemType::Float,
618 &createGroupReduceOpImpl<spirv::GroupFMulKHROp,
619 spirv::GroupNonUniformFMulOp>},
620 {ReduceType::MINUI, ElemType::Integer,
621 &createGroupReduceOpImpl<spirv::GroupUMinOp,
622 spirv::GroupNonUniformUMinOp>},
623 {ReduceType::MINSI, ElemType::Integer,
624 &createGroupReduceOpImpl<spirv::GroupSMinOp,
625 spirv::GroupNonUniformSMinOp>},
626 {ReduceType::MINNUMF, ElemType::Float,
627 &createGroupReduceOpImpl<spirv::GroupFMinOp,
628 spirv::GroupNonUniformFMinOp>},
629 {ReduceType::MAXUI, ElemType::Integer,
630 &createGroupReduceOpImpl<spirv::GroupUMaxOp,
631 spirv::GroupNonUniformUMaxOp>},
632 {ReduceType::MAXSI, ElemType::Integer,
633 &createGroupReduceOpImpl<spirv::GroupSMaxOp,
634 spirv::GroupNonUniformSMaxOp>},
635 {ReduceType::MAXNUMF, ElemType::Float,
636 &createGroupReduceOpImpl<spirv::GroupFMaxOp,
637 spirv::GroupNonUniformFMaxOp>},
638 {ReduceType::MINIMUMF, ElemType::Float,
639 &createGroupReduceOpImpl<spirv::GroupFMinOp,
640 spirv::GroupNonUniformFMinOp>},
641 {ReduceType::MAXIMUMF, ElemType::Float,
642 &createGroupReduceOpImpl<spirv::GroupFMaxOp,
643 spirv::GroupNonUniformFMaxOp>}};
644
645 for (const OpHandler &handler : handlers)
646 if (handler.kind == opType && elementType == handler.elemType)
647 return handler.func(builder, loc, arg, isGroup, isUniform, clusterSize);
648
649 return std::nullopt;
650}
651
652/// Pattern to convert a gpu.all_reduce op into a SPIR-V group op.
654 : public OpConversionPattern<gpu::AllReduceOp> {
655public:
656 using Base::Base;
657
658 LogicalResult
659 matchAndRewrite(gpu::AllReduceOp op, OpAdaptor adaptor,
660 ConversionPatternRewriter &rewriter) const override {
661 auto opType = op.getOp();
662
663 // gpu.all_reduce can have either reduction op attribute or reduction
664 // region. Only attribute version is supported.
665 if (!opType)
666 return failure();
667
668 auto result =
669 createGroupReduceOp(rewriter, op.getLoc(), adaptor.getValue(), *opType,
670 /*isGroup*/ true, op.getUniform(), std::nullopt);
671 if (!result)
672 return failure();
673
674 rewriter.replaceOp(op, *result);
675 return success();
676 }
677};
678
679/// Pattern to convert a gpu.subgroup_reduce op into a SPIR-V group op.
681 : public OpConversionPattern<gpu::SubgroupReduceOp> {
682public:
683 using Base::Base;
684
685 LogicalResult
686 matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
687 ConversionPatternRewriter &rewriter) const override {
688 if (op.getClusterStride() > 1) {
689 return rewriter.notifyMatchFailure(
690 op, "lowering for cluster stride > 1 is not implemented");
691 }
692
693 if (!isa<spirv::ScalarType>(adaptor.getValue().getType()))
694 return rewriter.notifyMatchFailure(op, "reduction type is not a scalar");
695
697 rewriter, op.getLoc(), adaptor.getValue(), adaptor.getOp(),
698 /*isGroup=*/false, adaptor.getUniform(), op.getClusterSize());
699 if (!result)
700 return failure();
701
702 rewriter.replaceOp(op, *result);
703 return success();
704 }
705};
706
707// Formulate a unique variable/constant name after
708// searching in the module for existing variable/constant names.
709// This is to avoid name collision with existing variables.
710// Example: printfMsg0, printfMsg1, printfMsg2, ...
711static std::string makeVarName(spirv::ModuleOp moduleOp, llvm::Twine prefix) {
712 std::string name;
713 unsigned number = 0;
714
715 do {
716 name.clear();
717 name = (prefix + llvm::Twine(number++)).str();
718 } while (moduleOp.lookupSymbol(name));
719
720 return name;
721}
722
/// Pattern to convert a gpu.printf op into a SPIR-V CLPrintf op.
///
/// The format string is materialized as a UniformConstant global variable
/// initialized from a spec-constant composite (one i8 spec constant per
/// character, including the trailing NUL), then bitcast to an i8 pointer and
/// passed to spirv.CL.printf together with the converted arguments.
LogicalResult GPUPrintfConversion::matchAndRewrite(
    gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {

  Location loc = gpuPrintfOp.getLoc();

  // Must be nested inside a spirv.module so globals can be created.
  auto moduleOp = gpuPrintfOp->getParentOfType<spirv::ModuleOp>();
  if (!moduleOp)
    return failure();

  // SPIR-V global variable is used to initialize printf
  // format string value, if there are multiple printf messages,
  // each global var needs to be created with a unique name.
  std::string globalVarName = makeVarName(moduleOp, llvm::Twine("printfMsg"));
  spirv::GlobalVariableOp globalVar;

  IntegerType i8Type = rewriter.getI8Type();
  IntegerType i32Type = rewriter.getI32Type();

  // Each character of printf format string is
  // stored as a spec constant. We need to create
  // unique name for this spec constant like
  // @printfMsg0_sc0, @printfMsg0_sc1, ... by searching in the module
  // for existing spec constant names.
  auto createSpecConstant = [&](unsigned value) {
    auto attr = rewriter.getI8IntegerAttr(value);
    std::string specCstName =
        makeVarName(moduleOp, llvm::Twine(globalVarName) + "_sc");

    return spirv::SpecConstantOp::create(
        rewriter, loc, rewriter.getStringAttr(specCstName), attr);
  };
  // Scoped block: temporarily move the insertion point to module level to
  // create the spec constants, the composite, and the global variable; the
  // guard restores the insertion point afterwards.
  {
    Operation *parent =
        SymbolTable::getNearestSymbolTable(gpuPrintfOp->getParentOp());

    ConversionPatternRewriter::InsertionGuard guard(rewriter);

    Block &entryBlock = *parent->getRegion(0).begin();
    rewriter.setInsertionPointToStart(
        &entryBlock); // insertion point at module level

    // Create Constituents with SpecConstant by scanning format string
    // Each character of format string is stored as a spec constant
    // and then these spec constants are used to create a
    // SpecConstantCompositeOp.
    llvm::SmallString<20> formatString(adaptor.getFormat());
    formatString.push_back('\0'); // Null terminate for C.
    SmallVector<Attribute, 4> constituents;
    for (char c : formatString) {
      spirv::SpecConstantOp cSpecConstantOp = createSpecConstant(c);
      constituents.push_back(SymbolRefAttr::get(cSpecConstantOp));
    }

    // Create SpecConstantCompositeOp to initialize the global variable
    size_t contentSize = constituents.size();
    auto globalType = spirv::ArrayType::get(i8Type, contentSize);
    spirv::SpecConstantCompositeOp specCstComposite;
    // There will be one SpecConstantCompositeOp per printf message/global var,
    // so no need do lookup for existing ones.
    std::string specCstCompositeName =
        (llvm::Twine(globalVarName) + "_scc").str();

    specCstComposite = spirv::SpecConstantCompositeOp::create(
        rewriter, loc, TypeAttr::get(globalType),
        rewriter.getStringAttr(specCstCompositeName),
        rewriter.getArrayAttr(constituents));

    auto ptrType = spirv::PointerType::get(
        globalType, spirv::StorageClass::UniformConstant);

    // Define a GlobalVarOp initialized using specialized constants
    // that is used to specify the printf format string
    // to be passed to the SPIRV CLPrintfOp.
    globalVar = spirv::GlobalVariableOp::create(
        rewriter, loc, ptrType, globalVarName,
        FlatSymbolRefAttr::get(specCstComposite));

    globalVar->setAttr("Constant", rewriter.getUnitAttr());
  }
  // Get SSA value of Global variable and create pointer to i8 to point to
  // the format string.
  Value globalPtr = spirv::AddressOfOp::create(rewriter, loc, globalVar);
  Value fmtStr = spirv::BitcastOp::create(
      rewriter, loc,
      spirv::PointerType::get(i8Type, spirv::StorageClass::UniformConstant),
      globalPtr);

  // Get printf arguments.
  auto printfArgs = llvm::to_vector_of<Value, 4>(adaptor.getArgs());

  spirv::CLPrintfOp::create(rewriter, loc, i32Type, fmtStr, printfArgs);

  // Need to erase the gpu.printf op as gpu.printf does not use result vs
  // spirv::CLPrintfOp has i32 resultType so cannot replace with new SPIR-V
  // printf op.
  rewriter.eraseOp(gpuPrintfOp);

  return success();
}
825
826//===----------------------------------------------------------------------===//
827// GPU To SPIRV Patterns.
828//===----------------------------------------------------------------------===//
829
832 patterns.add<
833 GPUBarrierConversion, GPUFuncOpConversion, GPUModuleConversion,
834 GPUReturnOpConversion, GPUShuffleConversion, GPURotateConversion,
835 LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
836 LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
837 LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,
838 LaunchConfigConversion<gpu::ThreadIdOp,
839 spirv::BuiltIn::LocalInvocationId>,
840 LaunchConfigConversion<gpu::GlobalIdOp,
841 spirv::BuiltIn::GlobalInvocationId>,
842 SingleDimLaunchConfigConversion<gpu::SubgroupIdOp,
843 spirv::BuiltIn::SubgroupId>,
844 SingleDimLaunchConfigConversion<gpu::NumSubgroupsOp,
845 spirv::BuiltIn::NumSubgroups>,
846 SingleDimLaunchConfigConversion<gpu::SubgroupSizeOp,
847 spirv::BuiltIn::SubgroupSize>,
848 SingleDimLaunchConfigConversion<
849 gpu::LaneIdOp, spirv::BuiltIn::SubgroupLocalInvocationId>,
850 WorkGroupSizeConversion, GPUAllReduceConversion,
851 GPUSubgroupReduceConversion, GPUPrintfConversion>(typeConverter,
852 patterns.getContext());
853}
return success()
static std::optional< Value > createGroupReduceOp(OpBuilder &builder, Location loc, Value arg, gpu::AllReduceOperation opType, bool isGroup, bool isUniform, std::optional< uint32_t > clusterSize)
static LogicalResult getDefaultABIAttrs(const spirv::TargetEnv &targetEnv, gpu::GPUFuncOp funcOp, SmallVectorImpl< spirv::InterfaceVarABIAttr > &argABI)
Populates argABI with spirv.interface_var_abi attributes for lowering gpu.func to spirv....
static constexpr const char kSPIRVModule[]
static Value createGroupReduceOpImpl(OpBuilder &builder, Location loc, Value arg, bool isGroup, bool isUniform, std::optional< uint32_t > clusterSize)
static spirv::FuncOp lowerAsEntryFunction(gpu::GPUFuncOp funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter, spirv::EntryPointABIAttr entryPointInfo, ArrayRef< spirv::InterfaceVarABIAttr > argABIInfo)
static std::string makeVarName(spirv::ModuleOp moduleOp, llvm::Twine prefix)
ArrayAttr()
b getContext())
Pattern to convert a gpu.all_reduce op into a SPIR-V group op.
LogicalResult matchAndRewrite(gpu::AllReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
Pattern to convert a gpu.subgroup_reduce op into a SPIR-V group op.
LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:228
IntegerType getI32Type()
Definition Builders.cpp:63
MLIRContext * getContext() const
Definition Builders.h:56
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:207
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:686
Block & back()
Definition Region.h:64
iterator begin()
Definition Region.h:55
Type conversion from builtin types to SPIR-V types for shader interface.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition SymbolTable.h:76
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static ArrayType get(Type elementType, unsigned elementCount)
An attribute that specifies the information regarding the interface variable: descriptor set,...
static PointerType get(Type pointeeType, StorageClass storageClass)
ResourceLimitsAttr getResourceLimits() const
Returns the target resource limits.
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
TargetEnvAttr getAttr() const
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
StringRef getInterfaceVarABIAttrName()
Returns the attribute name for specifying argument ABI information.
bool needsInterfaceVarABIAttrs(TargetEnvAttr targetAttr)
Returns whether the given SPIR-V target (described by TargetEnvAttr) needs ABI attributes for interfa...
InterfaceVarABIAttr getInterfaceVarABIAttr(unsigned descriptorSet, unsigned binding, std::optional< StorageClass > storageClass, MLIRContext *context)
Gets the InterfaceVarABIAttr given its fields.
Value getBuiltinVariableValue(Operation *op, BuiltIn builtin, Type integerType, OpBuilder &builder, StringRef prefix="__builtin__", StringRef suffix="__")
Returns the value for the given builtin variable.
EntryPointABIAttr lookupEntryPointABI(Operation *op)
Queries the entry point ABI on the nearest function-like op containing the given op.
StringRef getTargetEnvAttrName()
Returns the attribute name for specifying SPIR-V target environment.
DenseI32ArrayAttr lookupLocalWorkGroupSize(Operation *op)
Queries the local workgroup size from entry point ABI on the nearest function-like op containing the ...
AddressingModel getAddressingModel(TargetEnvAttr targetAttr, bool use64bitAddress)
Returns addressing model selected based on target environment.
FailureOr< MemoryModel > getMemoryModel(TargetEnvAttr targetAttr)
Returns memory model selected based on target environment.
StringRef getEntryPointABIAttrName()
Returns the attribute name for specifying entry point information.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
const FrozenRewritePatternSet & patterns
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
void populateGPUToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating GPU Ops to SPIR-V ops.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369