MLIR 23.0.0git
GPUToSPIRV.cpp
Go to the documentation of this file.
1//===- GPUToSPIRV.cpp - GPU to SPIR-V Patterns ----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns to convert GPU dialect to SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
21#include "mlir/IR/Matchers.h"
23#include <optional>
24
25using namespace mlir;
26
27static constexpr const char kSPIRVModule[] = "__spv__";
28
29namespace {
/// Pattern lowering GPU block/thread size/id to loading SPIR-V invocation
/// builtin variables.
template <typename SourceOp, spirv::BuiltIn builtin>
class LaunchConfigConversion : public OpConversionPattern<SourceOp> {
public:
  using OpConversionPattern<SourceOp>::OpConversionPattern;

  // Defined out-of-line below: loads the `builtin` variable (a 3-element
  // vector in shader environments) and extracts the dimension queried by `op`.
  LogicalResult
  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
41
/// Pattern lowering subgroup size/id to loading SPIR-V invocation
/// builtin variables.
template <typename SourceOp, spirv::BuiltIn builtin>
class SingleDimLaunchConfigConversion : public OpConversionPattern<SourceOp> {
public:
  using OpConversionPattern<SourceOp>::OpConversionPattern;

  // Defined out-of-line below: loads the scalar i32 `builtin` variable and
  // widens it to the converter's index type if needed.
  LogicalResult
  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
53
/// This is separate because in Vulkan workgroup size is exposed to shaders via
/// a constant with WorkgroupSize decoration. So here we cannot generate a
/// builtin variable; instead the information in the `spirv.entry_point_abi`
/// attribute on the surrounding FuncOp is used to replace the gpu::BlockDimOp.
class WorkGroupSizeConversion : public OpConversionPattern<gpu::BlockDimOp> {
public:
  WorkGroupSizeConversion(const TypeConverter &typeConverter,
                          MLIRContext *context)
      // Higher benefit than the generic builtin-variable lowering so this
      // pattern is preferred when both could apply.
      : OpConversionPattern(typeConverter, context, /*benefit=*/10) {}

  LogicalResult
  matchAndRewrite(gpu::BlockDimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
68
/// Pattern to convert a kernel function in GPU dialect within a spirv.module.
class GPUFuncOpConversion final : public OpConversionPattern<gpu::GPUFuncOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(gpu::GPUFuncOp funcOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;

private:
  // NOTE(review): not referenced anywhere in this file's visible code —
  // confirm whether this member is still needed.
  SmallVector<int32_t, 3> workGroupSizeAsInt32;
};
81
/// Pattern to convert a gpu.module to a spirv.module, deducing the addressing
/// and memory models from the target environment.
class GPUModuleConversion final : public OpConversionPattern<gpu::GPUModuleOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(gpu::GPUModuleOp moduleOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
91
/// Pattern to convert a gpu.return into a SPIR-V return.
// TODO: This can go to DRR when GPU return has operands.
class GPUReturnOpConversion final : public OpConversionPattern<gpu::ReturnOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(gpu::ReturnOp returnOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
102
/// Pattern to convert a gpu.barrier op into a spirv.ControlBarrier op with
/// workgroup scope and acquire-release workgroup-memory semantics.
class GPUBarrierConversion final : public OpConversionPattern<gpu::BarrierOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(gpu::BarrierOp barrierOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
112
/// Pattern to convert a gpu.shuffle op into a spirv.GroupNonUniformShuffle op
/// (or its Xor/Up/Down variant, depending on the shuffle mode).
class GPUShuffleConversion final : public OpConversionPattern<gpu::ShuffleOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(gpu::ShuffleOp shuffleOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
122
/// Pattern to convert a gpu.rotate op into a spirv.GroupNonUniformRotateKHROp.
class GPURotateConversion final : public OpConversionPattern<gpu::RotateOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(gpu::RotateOp rotateOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
132
/// Pattern to convert a gpu.subgroup_broadcast op into a
/// spirv.GroupNonUniformBroadcast op (or GroupNonUniformBroadcastFirst when
/// broadcasting from the first active lane).
class GPUSubgroupBroadcastConversion final
    : public OpConversionPattern<gpu::SubgroupBroadcastOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(gpu::SubgroupBroadcastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
144
/// Pattern to convert a gpu.ballot op into a spirv.GroupNonUniformBallot op,
/// packing the returned vector<4xi32> bitmask into the op's i32/i64 result.
class GPUBallotConversion final : public OpConversionPattern<gpu::BallotOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(gpu::BallotOp ballotOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
153
/// Pattern to convert a gpu.printf op into a SPIR-V CLPrintf op; the format
/// string is materialized as spec constants backing a global variable.
class GPUPrintfConversion final : public OpConversionPattern<gpu::PrintfOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
162
163} // namespace
164
165//===----------------------------------------------------------------------===//
166// Builtins.
167//===----------------------------------------------------------------------===//
168
169template <typename SourceOp, spirv::BuiltIn builtin>
170LogicalResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
171 SourceOp op, typename SourceOp::Adaptor adaptor,
172 ConversionPatternRewriter &rewriter) const {
173 auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
174 Type indexType = typeConverter->getIndexType();
175
176 // For Vulkan, these SPIR-V builtin variables are required to be a vector of
177 // type <3xi32> by the spec:
178 // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/NumWorkgroups.html
179 // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/WorkgroupId.html
180 // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/WorkgroupSize.html
181 // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/LocalInvocationId.html
182 // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/LocalInvocationId.html
183 // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/GlobalInvocationId.html
184 //
185 // For OpenCL, it depends on the Physical32/Physical64 addressing model:
186 // https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_Env.html#_built_in_variables
187 bool forShader =
188 typeConverter->getTargetEnv().allows(spirv::Capability::Shader);
189 Type builtinType = forShader ? rewriter.getIntegerType(32) : indexType;
190
191 Value vector =
192 spirv::getBuiltinVariableValue(op, builtin, builtinType, rewriter);
193 Value dim = spirv::CompositeExtractOp::create(
194 rewriter, op.getLoc(), builtinType, vector,
195 rewriter.getI32ArrayAttr({static_cast<int32_t>(op.getDimension())}));
196 if (forShader && builtinType != indexType)
197 dim = spirv::UConvertOp::create(rewriter, op.getLoc(), indexType, dim);
198 rewriter.replaceOp(op, dim);
199 return success();
200}
201
202template <typename SourceOp, spirv::BuiltIn builtin>
203LogicalResult
204SingleDimLaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
205 SourceOp op, typename SourceOp::Adaptor adaptor,
206 ConversionPatternRewriter &rewriter) const {
207 auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
208 Type indexType = typeConverter->getIndexType();
209 Type i32Type = rewriter.getIntegerType(32);
210
211 // For Vulkan, these SPIR-V builtin variables are required to be a vector of
212 // type i32 by the spec:
213 // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/NumSubgroups.html
214 // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/SubgroupId.html
215 // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/SubgroupSize.html
216 //
217 // For OpenCL, they are also required to be i32:
218 // https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_Env.html#_built_in_variables
219 Value builtinValue =
220 spirv::getBuiltinVariableValue(op, builtin, i32Type, rewriter);
221 if (i32Type != indexType)
222 builtinValue = spirv::UConvertOp::create(rewriter, op.getLoc(), indexType,
223 builtinValue);
224 rewriter.replaceOp(op, builtinValue);
225 return success();
226}
227
LogicalResult WorkGroupSizeConversion::matchAndRewrite(
    gpu::BlockDimOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // NOTE(review): the declaration/initialization of `workGroupSizeAttr`
  // (presumably looked up from the surrounding entry-point ABI) is missing
  // from this view of the file.
  if (!workGroupSizeAttr)
    return failure();

  // Select the workgroup size along the dimension (x/y/z) the op queries.
  int val =
      workGroupSizeAttr.asArrayRef()[static_cast<int32_t>(op.getDimension())];
  auto convertedType =
      getTypeConverter()->convertType(op.getResult().getType());
  if (!convertedType)
    return failure();
  // The workgroup size is known statically here, so the query collapses to a
  // spirv.Constant.
  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
      op, convertedType, IntegerAttr::get(convertedType, val));
  return success();
}
245
246//===----------------------------------------------------------------------===//
247// GPUFuncOp
248//===----------------------------------------------------------------------===//
249
// Legalizes a GPU function as an entry SPIR-V function.
//
// Returns the new spirv.func on success, nullptr on failure (a diagnostic is
// emitted for user-visible errors). The gpu.func is erased.
static spirv::FuncOp
lowerAsEntryFunction(gpu::GPUFuncOp funcOp, const TypeConverter &typeConverter,
                     ConversionPatternRewriter &rewriter,
                     spirv::EntryPointABIAttr entryPointInfo,
                     // NOTE(review): the final parameter declaration (the
                     // `argABIInfo` list used below) is missing from this
                     // view of the file.
  auto fnType = funcOp.getFunctionType();
  if (fnType.getNumResults()) {
    // NOTE(review): missing space between the concatenated string literals —
    // the diagnostic currently reads "...entry functionswith no return...".
    funcOp.emitError("SPIR-V lowering only supports entry functions"
                     "with no return values right now");
    return nullptr;
  }
  // ABI info must be either absent entirely or given for every argument.
  if (!argABIInfo.empty() && fnType.getNumInputs() != argABIInfo.size()) {
    funcOp.emitError(
        "lowering as entry functions requires ABI info for all arguments "
        "or none of them");
    return nullptr;
  }
  // Update the signature to valid SPIR-V types and add the ABI
  // attributes. These will be "materialized" by using the
  // LowerABIAttributesPass.
  TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
  {
    for (const auto &argType :
         enumerate(funcOp.getFunctionType().getInputs())) {
      auto convertedType = typeConverter.convertType(argType.value());
      if (!convertedType)
        return nullptr;
      signatureConverter.addInputs(argType.index(), convertedType);
    }
  }
  auto newFuncOp = spirv::FuncOp::create(
      rewriter, funcOp.getLoc(), funcOp.getName(),
      rewriter.getFunctionType(signatureConverter.getConvertedTypes(), {}));
  // Carry over all attributes except the function type and symbol name,
  // which the builder above already set.
  for (const auto &namedAttr : funcOp->getAttrs()) {
    if (namedAttr.getName() == funcOp.getFunctionTypeAttrName() ||
        namedAttr.getName() == SymbolTable::getSymbolAttrName())
      continue;
    newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
  }

  // Move the body over and rewrite block arguments to the converted types.
  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
                              newFuncOp.end());
  if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
                                         &signatureConverter)))
    return nullptr;
  rewriter.eraseOp(funcOp);

  // Set the attributes for argument and the function.
  StringRef argABIAttrName = spirv::getInterfaceVarABIAttrName();
  for (auto argIndex : llvm::seq<unsigned>(0, argABIInfo.size())) {
    newFuncOp.setArgAttr(argIndex, argABIAttrName, argABIInfo[argIndex]);
  }
  newFuncOp->setAttr(spirv::getEntryPointABIAttrName(), entryPointInfo);

  return newFuncOp;
}
307
/// Populates `argABI` with spirv.interface_var_abi attributes for lowering
/// gpu.func to spirv.func if no arguments have the attributes set
/// already. Returns failure if any argument has the ABI attribute set already.
static LogicalResult
getDefaultABIAttrs(const spirv::TargetEnv &targetEnv, gpu::GPUFuncOp funcOp,
                   // NOTE(review): the output-list parameter (`argABI`, used
                   // below) is missing from this view of the file.
  // Targets that do not need interface variable ABI attributes require no
  // work here.
  if (!spirv::needsInterfaceVarABIAttrs(targetEnv))
    return success();

  for (auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
    if (funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
        // NOTE(review): the attribute-name argument and the closing
        // parenthesis of this call are missing from this view of the file.
      return failure();
    // Vulkan's interface variable requirements needs scalars to be wrapped in a
    // struct. The struct held in storage buffer.
    std::optional<spirv::StorageClass> sc;
    if (funcOp.getArgument(argIndex).getType().isIntOrIndexOrFloat())
      sc = spirv::StorageClass::StorageBuffer;
    argABI.push_back(
        spirv::getInterfaceVarABIAttr(0, argIndex, sc, funcOp.getContext()));
  }
  return success();
}
331
LogicalResult GPUFuncOpConversion::matchAndRewrite(
    gpu::GPUFuncOp funcOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // Only gpu.func ops marked as kernels are handled by this pattern.
  if (!gpu::GPUDialect::isKernel(funcOp))
    return failure();

  auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
  SmallVector<spirv::InterfaceVarABIAttr, 4> argABI;
  // First try to synthesize default ABI attributes. If that fails (some
  // argument already carries one), require explicit ABI attributes on every
  // argument instead.
  if (failed(
          getDefaultABIAttrs(typeConverter->getTargetEnv(), funcOp, argABI))) {
    argABI.clear();
    for (auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
      // If the ABI is already specified, use it.
      auto abiAttr = funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
          // NOTE(review): the attribute-name argument of this call is missing
          // from this view of the file.
      if (!abiAttr) {
        funcOp.emitRemark(
            "match failure: missing 'spirv.interface_var_abi' attribute at "
            "argument ")
            << argIndex;
        return failure();
      }
      argABI.push_back(abiAttr);
    }
  }

  // The entry-point ABI attribute is mandatory for kernel lowering.
  auto entryPointAttr = spirv::lookupEntryPointABI(funcOp);
  if (!entryPointAttr) {
    funcOp.emitRemark(
        "match failure: missing 'spirv.entry_point_abi' attribute");
    return failure();
  }
  spirv::FuncOp newFuncOp = lowerAsEntryFunction(
      funcOp, *getTypeConverter(), rewriter, entryPointAttr, argABI);
  if (!newFuncOp)
    return failure();
  // The gpu.kernel marker has no meaning on the spirv.func.
  newFuncOp->removeAttr(
      rewriter.getStringAttr(gpu::GPUDialect::getKernelFuncAttrName()));
  return success();
}
372
373//===----------------------------------------------------------------------===//
374// ModuleOp with gpu.module.
375//===----------------------------------------------------------------------===//
376
LogicalResult GPUModuleConversion::matchAndRewrite(
    gpu::GPUModuleOp moduleOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
  const spirv::TargetEnv &targetEnv = typeConverter->getTargetEnv();
  // Addressing model depends on whether 64-bit indices are in use; memory
  // model is deduced from the target environment.
  spirv::AddressingModel addressingModel = spirv::getAddressingModel(
      targetEnv, typeConverter->getOptions().use64bitIndex);
  FailureOr<spirv::MemoryModel> memoryModel = spirv::getMemoryModel(targetEnv);
  if (failed(memoryModel))
    return moduleOp.emitRemark(
        "cannot deduce memory model from 'spirv.target_env'");

  // Add a keyword to the module name to avoid symbolic conflict.
  std::string spvModuleName = (kSPIRVModule + moduleOp.getName()).str();
  auto spvModule = spirv::ModuleOp::create(
      rewriter, moduleOp.getLoc(), addressingModel, *memoryModel, std::nullopt,
      StringRef(spvModuleName));

  // Move the region from the module op into the SPIR-V module.
  Region &spvModuleRegion = spvModule.getRegion();
  rewriter.inlineRegionBefore(moduleOp.getBodyRegion(), spvModuleRegion,
                              spvModuleRegion.begin());
  // The spirv.module build method adds a block. Remove that.
  rewriter.eraseBlock(&spvModuleRegion.back());

  // Some of the patterns call `lookupTargetEnv` during conversion and they
  // will fail if called after GPUModuleConversion and we don't preserve
  // `TargetEnv` attribute.
  // Copy TargetEnvAttr only if it is attached directly to the GPUModuleOp.
  if (auto attr = moduleOp->getAttrOfType<spirv::TargetEnvAttr>(
      // NOTE(review): the attribute-name argument and closing parenthesis of
      // this call are missing from this view of the file.
    spvModule->setAttr(spirv::getTargetEnvAttrName(), attr);
  // Otherwise fall back to the first spirv.target_env found in the `targets`
  // array attribute.
  if (ArrayAttr targets = moduleOp.getTargetsAttr()) {
    for (Attribute targetAttr : targets)
      if (auto spirvTargetEnvAttr =
              dyn_cast<spirv::TargetEnvAttr>(targetAttr)) {
        spvModule->setAttr(spirv::getTargetEnvAttrName(), spirvTargetEnvAttr);
        break;
      }
  }

  rewriter.eraseOp(moduleOp);
  return success();
}
421
422//===----------------------------------------------------------------------===//
423// GPU return inside kernel functions to SPIR-V return.
424//===----------------------------------------------------------------------===//
425
426LogicalResult GPUReturnOpConversion::matchAndRewrite(
427 gpu::ReturnOp returnOp, OpAdaptor adaptor,
428 ConversionPatternRewriter &rewriter) const {
429 if (!adaptor.getOperands().empty())
430 return failure();
431
432 rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp);
433 return success();
434}
435
436//===----------------------------------------------------------------------===//
437// Barrier.
438//===----------------------------------------------------------------------===//
439
440LogicalResult GPUBarrierConversion::matchAndRewrite(
441 gpu::BarrierOp barrierOp, OpAdaptor adaptor,
442 ConversionPatternRewriter &rewriter) const {
443 MLIRContext *context = getContext();
444 // Both execution and memory scope should be workgroup.
445 auto scope = spirv::ScopeAttr::get(context, spirv::Scope::Workgroup);
446 // Require acquire and release memory semantics for workgroup memory.
447 auto memorySemantics = spirv::MemorySemanticsAttr::get(
448 context, spirv::MemorySemantics::WorkgroupMemory |
449 spirv::MemorySemantics::AcquireRelease);
450 rewriter.replaceOpWithNewOp<spirv::ControlBarrierOp>(barrierOp, scope, scope,
451 memorySemantics);
452 return success();
453}
454
455//===----------------------------------------------------------------------===//
456// Shuffle
457//===----------------------------------------------------------------------===//
458
LogicalResult GPUShuffleConversion::matchAndRewrite(
    gpu::ShuffleOp shuffleOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // Require the shuffle width to be the same as the target's subgroup size,
  // given that for SPIR-V non-uniform subgroup ops, we cannot select
  // participating invocations.
  auto targetEnv = getTypeConverter<SPIRVTypeConverter>()->getTargetEnv();
  unsigned subgroupSize =
      targetEnv.getAttr().getResourceLimits().getSubgroupSize();
  // The width operand must be a compile-time constant equal to that size.
  IntegerAttr widthAttr;
  if (!matchPattern(shuffleOp.getWidth(), m_Constant(&widthAttr)) ||
      widthAttr.getValue().getZExtValue() != subgroupSize)
    return rewriter.notifyMatchFailure(
        shuffleOp, "shuffle width and target subgroup size mismatch");

  assert(!adaptor.getOffset().getType().isSignedInteger() &&
         "shuffle offset must be a signless/unsigned integer");

  Location loc = shuffleOp.getLoc();
  auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
  Value result;   // The shuffled value.
  Value validVal; // i1 flag: whether the source lane was in range.

  switch (shuffleOp.getMode()) {
  case gpu::ShuffleMode::XOR: {
    result = spirv::GroupNonUniformShuffleXorOp::create(
        rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset());
    // XOR with a constant mask always addresses a lane inside the subgroup,
    // so the result is unconditionally valid.
    validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(),
                                         shuffleOp.getLoc(), rewriter);
    break;
  }
  case gpu::ShuffleMode::IDX: {
    result = spirv::GroupNonUniformShuffleOp::create(
        rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset());
    // Direct-index shuffle is treated as always valid here.
    validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(),
                                         shuffleOp.getLoc(), rewriter);
    break;
  }
  case gpu::ShuffleMode::DOWN: {
    result = spirv::GroupNonUniformShuffleDownOp::create(
        rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset());

    // Valid iff laneId + offset stays below the shuffle width.
    Value laneId = gpu::LaneIdOp::create(rewriter, loc, widthAttr);
    Value resultLaneId =
        arith::AddIOp::create(rewriter, loc, laneId, adaptor.getOffset());
    validVal = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ult,
                                     resultLaneId, adaptor.getWidth());
    break;
  }
  case gpu::ShuffleMode::UP: {
    result = spirv::GroupNonUniformShuffleUpOp::create(
        rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset());

    // Valid iff laneId - offset does not underflow (signed compare >= 0).
    Value laneId = gpu::LaneIdOp::create(rewriter, loc, widthAttr);
    Value resultLaneId =
        arith::SubIOp::create(rewriter, loc, laneId, adaptor.getOffset());
    auto i32Type = rewriter.getIntegerType(32);
    validVal = arith::CmpIOp::create(
        rewriter, loc, arith::CmpIPredicate::sge, resultLaneId,
        arith::ConstantOp::create(rewriter, loc, i32Type,
                                  rewriter.getIntegerAttr(i32Type, 0)));
    break;
  }
  }

  rewriter.replaceOp(shuffleOp, {result, validVal});
  return success();
}
527
528//===----------------------------------------------------------------------===//
529// Rotate
530//===----------------------------------------------------------------------===//
531
532LogicalResult GPURotateConversion::matchAndRewrite(
533 gpu::RotateOp rotateOp, OpAdaptor adaptor,
534 ConversionPatternRewriter &rewriter) const {
535 const spirv::TargetEnv &targetEnv =
536 getTypeConverter<SPIRVTypeConverter>()->getTargetEnv();
537 unsigned subgroupSize =
538 targetEnv.getAttr().getResourceLimits().getSubgroupSize();
539 unsigned width = rotateOp.getWidth();
540 if (width > subgroupSize)
541 return rewriter.notifyMatchFailure(
542 rotateOp, "rotate width is larger than target subgroup size");
543
544 Location loc = rotateOp.getLoc();
545 auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
546 Value offsetVal =
547 arith::ConstantOp::create(rewriter, loc, adaptor.getOffsetAttr());
548 Value widthVal =
549 arith::ConstantOp::create(rewriter, loc, adaptor.getWidthAttr());
550 Value rotateResult = spirv::GroupNonUniformRotateKHROp::create(
551 rewriter, loc, scope, adaptor.getValue(), offsetVal, widthVal);
552 Value validVal;
553 if (width == subgroupSize) {
554 validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(), loc, rewriter);
555 } else {
556 IntegerAttr widthAttr = adaptor.getWidthAttr();
557 Value laneId = gpu::LaneIdOp::create(rewriter, loc, widthAttr);
558 validVal = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ult,
559 laneId, widthVal);
560 }
561
562 rewriter.replaceOp(rotateOp, {rotateResult, validVal});
563 return success();
564}
565
566//===----------------------------------------------------------------------===//
567// Subgroup broadcast
568//===----------------------------------------------------------------------===//
569
570LogicalResult GPUSubgroupBroadcastConversion::matchAndRewrite(
571 gpu::SubgroupBroadcastOp op, OpAdaptor adaptor,
572 ConversionPatternRewriter &rewriter) const {
573 Location loc = op.getLoc();
574 auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
575 Value result;
576
577 switch (op.getBroadcastType()) {
578 case gpu::BroadcastType::specific_lane:
579 result = spirv::GroupNonUniformBroadcastOp::create(
580 rewriter, loc, scope, adaptor.getSrc(), adaptor.getLane());
581 break;
582 case gpu::BroadcastType::first_active_lane:
583 result = spirv::GroupNonUniformBroadcastFirstOp::create(
584 rewriter, loc, scope, adaptor.getSrc());
585 break;
586 }
587
588 rewriter.replaceOp(op, result);
589 return success();
590}
591
592LogicalResult GPUBallotConversion::matchAndRewrite(
593 gpu::BallotOp ballotOp, OpAdaptor adaptor,
594 ConversionPatternRewriter &rewriter) const {
595 Location loc = ballotOp.getLoc();
596 auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
597 auto int32Type = rewriter.getI32Type();
598 auto vec4i32Type = VectorType::get({4}, int32Type);
599
600 // SPIR-V ballot returns vector<4xi32> to support subgroups up to 128 lanes.
601 Value ballot = spirv::GroupNonUniformBallotOp::create(
602 rewriter, loc, vec4i32Type, scope, adaptor.getPredicate());
603
604 auto intType = cast<IntegerType>(ballotOp.getType());
605 unsigned width = intType.getWidth();
606
607 if (width == 32) {
608 Value result =
609 spirv::CompositeExtractOp::create(rewriter, loc, ballot, {0});
610 rewriter.replaceOp(ballotOp, result);
611 } else if (width == 64) {
612 // Combine first two vector elements: low 32 bits + (high 32 bits << 32).
613 Value low = spirv::CompositeExtractOp::create(rewriter, loc, ballot, {0});
614 Value high = spirv::CompositeExtractOp::create(rewriter, loc, ballot, {1});
615
616 auto int64Type = rewriter.getI64Type();
617 Value lowExt = spirv::UConvertOp::create(rewriter, loc, int64Type, low);
618 Value highExt = spirv::UConvertOp::create(rewriter, loc, int64Type, high);
619
620 Value shift32 = spirv::ConstantOp::create(
621 rewriter, loc, int64Type, rewriter.getIntegerAttr(int64Type, 32));
622 Value highShifted =
623 spirv::ShiftLeftLogicalOp::create(rewriter, loc, highExt, shift32);
624
625 Value result =
626 spirv::BitwiseOrOp::create(rewriter, loc, lowExt, highShifted);
627 rewriter.replaceOp(ballotOp, result);
628 } else {
629 return rewriter.notifyMatchFailure(
630 ballotOp, "only i32 and i64 result types are supported for SPIR-V");
631 }
632
633 return success();
634}
635
636//===----------------------------------------------------------------------===//
637// Group ops
638//===----------------------------------------------------------------------===//
639
// NOTE(review): the opening signature of this helper (its name and the
// leading `OpBuilder &builder, Location loc, ...` parameters used below) is
// missing from this view of the file.
template <typename UniformOp, typename NonUniformOp>
    Value arg, bool isGroup, bool isUniform,
    std::optional<uint32_t> clusterSize) {
  Type type = arg.getType();
  // Workgroup-level reductions use the Workgroup scope; otherwise Subgroup.
  auto scope = mlir::spirv::ScopeAttr::get(builder.getContext(),
                                           isGroup ? spirv::Scope::Workgroup
                                                   : spirv::Scope::Subgroup);
  // Presence of a cluster size selects the ClusteredReduce group operation.
  auto groupOp = spirv::GroupOperationAttr::get(
      builder.getContext(), clusterSize.has_value()
                                ? spirv::GroupOperation::ClusteredReduce
                                : spirv::GroupOperation::Reduce);
  if (isUniform) {
    return UniformOp::create(builder, loc, type, scope, groupOp, arg)
        .getResult();
  }

  // Non-uniform ops optionally take the cluster size as an i32 constant.
  Value clusterSizeValue;
  if (clusterSize.has_value())
    clusterSizeValue = spirv::ConstantOp::create(
        builder, loc, builder.getI32Type(),
        builder.getIntegerAttr(builder.getI32Type(), *clusterSize));

  return NonUniformOp::create(builder, loc, type, scope, groupOp, arg,
                              clusterSizeValue)
      .getResult();
}
667
// NOTE(review): the signature line carrying this function's name and leading
// parameters (presumably `createGroupReduceOp(OpBuilder &builder, Location
// loc, Value arg, ...)`) is missing from this view of the file.
//
// Dispatches a gpu reduction kind + element type to the matching SPIR-V group
// op builder; returns std::nullopt for unsupported combinations.
static std::optional<Value>
    gpu::AllReduceOperation opType, bool isGroup,
    bool isUniform, std::optional<uint32_t> clusterSize) {
  enum class ElemType { Float, Boolean, Integer };
  using FuncT = Value (*)(OpBuilder &, Location, Value, bool, bool,
                          std::optional<uint32_t>);
  // One table row per supported (reduction kind, element type) pair.
  struct OpHandler {
    gpu::AllReduceOperation kind;
    ElemType elemType;
    FuncT func;
  };

  // Classify the reduced element type; i1 counts as Boolean, which has no
  // handler below and thus yields std::nullopt.
  Type type = arg.getType();
  ElemType elementType;
  if (isa<FloatType>(type)) {
    elementType = ElemType::Float;
  } else if (auto intTy = dyn_cast<IntegerType>(type)) {
    elementType = (intTy.getIntOrFloatBitWidth() == 1) ? ElemType::Boolean
                                                       : ElemType::Integer;
  } else {
    return std::nullopt;
  }

  // TODO(https://github.com/llvm/llvm-project/issues/73459): The SPIR-V spec
  // does not specify how -0.0 / +0.0 and NaN values are handled in *FMin/*FMax
  // reduction ops. We should account possible precision requirements in this
  // conversion.

  using ReduceType = gpu::AllReduceOperation;
  const OpHandler handlers[] = {
      {ReduceType::ADD, ElemType::Integer,
       &createGroupReduceOpImpl<spirv::GroupIAddOp,
                                spirv::GroupNonUniformIAddOp>},
      {ReduceType::ADD, ElemType::Float,
       &createGroupReduceOpImpl<spirv::GroupFAddOp,
                                spirv::GroupNonUniformFAddOp>},
      {ReduceType::MUL, ElemType::Integer,
       &createGroupReduceOpImpl<spirv::GroupIMulKHROp,
                                spirv::GroupNonUniformIMulOp>},
      {ReduceType::MUL, ElemType::Float,
       &createGroupReduceOpImpl<spirv::GroupFMulKHROp,
                                spirv::GroupNonUniformFMulOp>},
      {ReduceType::MINUI, ElemType::Integer,
       &createGroupReduceOpImpl<spirv::GroupUMinOp,
                                spirv::GroupNonUniformUMinOp>},
      {ReduceType::MINSI, ElemType::Integer,
       &createGroupReduceOpImpl<spirv::GroupSMinOp,
                                spirv::GroupNonUniformSMinOp>},
      {ReduceType::MINNUMF, ElemType::Float,
       &createGroupReduceOpImpl<spirv::GroupFMinOp,
                                spirv::GroupNonUniformFMinOp>},
      {ReduceType::MAXUI, ElemType::Integer,
       &createGroupReduceOpImpl<spirv::GroupUMaxOp,
                                spirv::GroupNonUniformUMaxOp>},
      {ReduceType::MAXSI, ElemType::Integer,
       &createGroupReduceOpImpl<spirv::GroupSMaxOp,
                                spirv::GroupNonUniformSMaxOp>},
      {ReduceType::MAXNUMF, ElemType::Float,
       &createGroupReduceOpImpl<spirv::GroupFMaxOp,
                                spirv::GroupNonUniformFMaxOp>},
      {ReduceType::MINIMUMF, ElemType::Float,
       &createGroupReduceOpImpl<spirv::GroupFMinOp,
                                spirv::GroupNonUniformFMinOp>},
      {ReduceType::MAXIMUMF, ElemType::Float,
       &createGroupReduceOpImpl<spirv::GroupFMaxOp,
                                spirv::GroupNonUniformFMaxOp>}};

  // Linear scan is fine for this small table.
  for (const OpHandler &handler : handlers)
    if (handler.kind == opType && elementType == handler.elemType)
      return handler.func(builder, loc, arg, isGroup, isUniform, clusterSize);

  return std::nullopt;
}
742
/// Pattern to convert a gpu.all_reduce op into a SPIR-V group op.
// NOTE(review): the class declaration line (this pattern's name) is missing
// from this view of the file.
    : public OpConversionPattern<gpu::AllReduceOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(gpu::AllReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto opType = op.getOp();

    // gpu.all_reduce can have either reduction op attribute or reduction
    // region. Only attribute version is supported.
    if (!opType)
      return failure();

    // Workgroup-level reduction; no cluster size applies here.
    auto result =
        createGroupReduceOp(rewriter, op.getLoc(), adaptor.getValue(), *opType,
                            /*isGroup*/ true, op.getUniform(), std::nullopt);
    if (!result)
      return failure();

    rewriter.replaceOp(op, *result);
    return success();
  }
};
769
/// Pattern to convert a gpu.subgroup_reduce op into a SPIR-V group op.
// NOTE(review): the class declaration line (this pattern's name) is missing
// from this view of the file.
    : public OpConversionPattern<gpu::SubgroupReduceOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Strided clusters have no SPIR-V counterpart here.
    if (op.getClusterStride() > 1) {
      return rewriter.notifyMatchFailure(
          op, "lowering for cluster stride > 1 is not implemented");
    }

    // Only scalar SPIR-V types can be reduced by the group ops.
    if (!isa<spirv::ScalarType>(adaptor.getValue().getType()))
      return rewriter.notifyMatchFailure(op, "reduction type is not a scalar");

    // NOTE(review): the line declaring `result` and opening the
    // createGroupReduceOp call is missing from this view of the file.
        rewriter, op.getLoc(), adaptor.getValue(), adaptor.getOp(),
        /*isGroup=*/false, adaptor.getUniform(), op.getClusterSize());
    if (!result)
      return failure();

    rewriter.replaceOp(op, *result);
    return success();
  }
};
797
798// Formulate a unique variable/constant name after
799// searching in the module for existing variable/constant names.
800// This is to avoid name collision with existing variables.
801// Example: printfMsg0, printfMsg1, printfMsg2, ...
802static std::string makeVarName(spirv::ModuleOp moduleOp, llvm::Twine prefix) {
803 std::string name;
804 unsigned number = 0;
805
806 do {
807 name.clear();
808 name = (prefix + llvm::Twine(number++)).str();
809 } while (moduleOp.lookupSymbol(name));
810
811 return name;
812}
813
814/// Pattern to convert a gpu.printf op into a SPIR-V CLPrintf op.
815
/// Lowers gpu.printf to spirv.CL.printf in three steps:
///   1. Materialize the format string as one i8 spec constant per character,
///      aggregated into a SpecConstantCompositeOp.
///   2. Use that composite to initialize a UniformConstant global variable.
///   3. Take the address of the global, bitcast it to an i8 pointer, and
///      emit spirv.CL.printf with the converted argument list.
/// Fails to match if the op is not nested inside a spirv.module.
816LogicalResult GPUPrintfConversion::matchAndRewrite(
817 gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
818 ConversionPatternRewriter &rewriter) const {
819
820 Location loc = gpuPrintfOp.getLoc();
821
 // The spec constants and global variable below must live at spirv.module
 // level; bail out if the op has not been placed inside one.
822 auto moduleOp = gpuPrintfOp->getParentOfType<spirv::ModuleOp>();
823 if (!moduleOp)
824 return failure();
825
826 // SPIR-V global variable is used to initialize printf
827 // format string value, if there are multiple printf messages,
828 // each global var needs to be created with a unique name.
829 std::string globalVarName = makeVarName(moduleOp, llvm::Twine("printfMsg"));
830 spirv::GlobalVariableOp globalVar;
831
832 IntegerType i8Type = rewriter.getI8Type();
833 IntegerType i32Type = rewriter.getI32Type();
834
835 // Each character of printf format string is
836 // stored as a spec constant. We need to create
837 // unique name for this spec constant like
838 // @printfMsg0_sc0, @printfMsg0_sc1, ... by searching in the module
839 // for existing spec constant names.
840 auto createSpecConstant = [&](unsigned value) {
841 auto attr = rewriter.getI8IntegerAttr(value);
842 std::string specCstName =
843 makeVarName(moduleOp, llvm::Twine(globalVarName) + "_sc");
844
845 return spirv::SpecConstantOp::create(
846 rewriter, loc, rewriter.getStringAttr(specCstName), attr);
847 };
 // Scoped so the InsertionGuard restores the rewriter's insertion point
 // (back inside the function body) once the module-level ops are created.
848 {
849 Operation *parent =
850 SymbolTable::getNearestSymbolTable(gpuPrintfOp->getParentOp());
851
852 ConversionPatternRewriter::InsertionGuard guard(rewriter);
853
854 Block &entryBlock = *parent->getRegion(0).begin();
855 rewriter.setInsertionPointToStart(
856 &entryBlock); // insertion point at module level
857
858 // Create Constituents with SpecConstant by scanning format string
859 // Each character of format string is stored as a spec constant
860 // and then these spec constants are used to create a
861 // SpecConstantCompositeOp.
862 llvm::SmallString<20> formatString(adaptor.getFormat());
863 formatString.push_back('\0'); // Null terminate for C.
864 SmallVector<Attribute, 4> constituents;
865 for (char c : formatString) {
866 spirv::SpecConstantOp cSpecConstantOp = createSpecConstant(c);
867 constituents.push_back(SymbolRefAttr::get(cSpecConstantOp));
868 }
869
870 // Create SpecConstantCompositeOp to initialize the global variable
871 size_t contentSize = constituents.size();
872 auto globalType = spirv::ArrayType::get(i8Type, contentSize);
873 spirv::SpecConstantCompositeOp specCstComposite;
874 // There will be one SpecConstantCompositeOp per printf message/global var,
875 // so no need do lookup for existing ones.
876 std::string specCstCompositeName =
877 (llvm::Twine(globalVarName) + "_scc").str();
878
879 specCstComposite = spirv::SpecConstantCompositeOp::create(
880 rewriter, loc, TypeAttr::get(globalType),
881 rewriter.getStringAttr(specCstCompositeName),
882 rewriter.getArrayAttr(constituents));
883
884 auto ptrType = spirv::PointerType::get(
885 globalType, spirv::StorageClass::UniformConstant);
886
887 // Define a GlobalVarOp initialized using specialized constants
888 // that is used to specify the printf format string
889 // to be passed to the SPIRV CLPrintfOp.
890 globalVar = spirv::GlobalVariableOp::create(
891 rewriter, loc, ptrType, globalVarName,
892 FlatSymbolRefAttr::get(specCstComposite));
893
 // Presumably this maps to the SPIR-V Constant decoration (read-only
 // variable) during serialization — TODO(review): confirm.
894 globalVar->setAttr("Constant", rewriter.getUnitAttr());
895 }
896 // Get SSA value of Global variable and create pointer to i8 to point to
897 // the format string.
898 Value globalPtr = spirv::AddressOfOp::create(rewriter, loc, globalVar);
899 Value fmtStr = spirv::BitcastOp::create(
900 rewriter, loc,
901 spirv::PointerType::get(i8Type, spirv::StorageClass::UniformConstant),
902 globalPtr);
903
904 // Get printf arguments.
905 auto printfArgs = llvm::to_vector_of<Value, 4>(adaptor.getArgs());
906
907 spirv::CLPrintfOp::create(rewriter, loc, i32Type, fmtStr, printfArgs);
908
909 // Need to erase the gpu.printf op as gpu.printf does not use result vs
910 // spirv::CLPrintfOp has i32 resultType so cannot replace with new SPIR-V
911 // printf op.
912 rewriter.eraseOp(gpuPrintfOp);
913
914 return success();
915}
916
917//===----------------------------------------------------------------------===//
918// GPU To SPIRV Patterns.
919//===----------------------------------------------------------------------===//
920
// Registers every GPU-to-SPIR-V conversion pattern defined in this file
// into `patterns`, sharing the given type converter.
// NOTE(review): Doxygen-extraction artifact — the function's signature line
// (per the index below: `void mlir::populateGPUToSPIRVPatterns(
// const SPIRVTypeConverter &typeConverter,`) was dropped before the
// continuation line that follows.
922 RewritePatternSet &patterns) {
923 patterns.add<
924 GPUBarrierConversion, GPUBallotConversion, GPUFuncOpConversion,
925 GPUModuleConversion, GPUReturnOpConversion, GPUShuffleConversion,
926 GPURotateConversion, GPUSubgroupBroadcastConversion,
 // Launch-configuration queries lower to loads of SPIR-V builtin
 // variables; each GPU op is paired with its builtin here.
927 LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
928 LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
929 LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,
930 LaunchConfigConversion<gpu::ThreadIdOp,
931 spirv::BuiltIn::LocalInvocationId>,
932 LaunchConfigConversion<gpu::GlobalIdOp,
933 spirv::BuiltIn::GlobalInvocationId>,
 // Subgroup-related queries use the single-dimension variant.
934 SingleDimLaunchConfigConversion<gpu::SubgroupIdOp,
935 spirv::BuiltIn::SubgroupId>,
936 SingleDimLaunchConfigConversion<gpu::NumSubgroupsOp,
937 spirv::BuiltIn::NumSubgroups>,
938 SingleDimLaunchConfigConversion<gpu::SubgroupSizeOp,
939 spirv::BuiltIn::SubgroupSize>,
940 SingleDimLaunchConfigConversion<
941 gpu::LaneIdOp, spirv::BuiltIn::SubgroupLocalInvocationId>,
942 WorkGroupSizeConversion, GPUAllReduceConversion,
943 GPUSubgroupReduceConversion, GPUPrintfConversion>(typeConverter,
944 patterns.getContext());
945}
return success()
static std::optional< Value > createGroupReduceOp(OpBuilder &builder, Location loc, Value arg, gpu::AllReduceOperation opType, bool isGroup, bool isUniform, std::optional< uint32_t > clusterSize)
static LogicalResult getDefaultABIAttrs(const spirv::TargetEnv &targetEnv, gpu::GPUFuncOp funcOp, SmallVectorImpl< spirv::InterfaceVarABIAttr > &argABI)
Populates argABI with spirv.interface_var_abi attributes for lowering gpu.func to spirv....
static constexpr const char kSPIRVModule[]
static Value createGroupReduceOpImpl(OpBuilder &builder, Location loc, Value arg, bool isGroup, bool isUniform, std::optional< uint32_t > clusterSize)
static spirv::FuncOp lowerAsEntryFunction(gpu::GPUFuncOp funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter, spirv::EntryPointABIAttr entryPointInfo, ArrayRef< spirv::InterfaceVarABIAttr > argABIInfo)
static std::string makeVarName(spirv::ModuleOp moduleOp, llvm::Twine prefix)
ArrayAttr()
b getContext())
Pattern to convert a gpu.all_reduce op into a SPIR-V group op.
LogicalResult matchAndRewrite(gpu::AllReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
Pattern to convert a gpu.subgroup_reduce op into a SPIR-V group op.
LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:232
IntegerType getI32Type()
Definition Builders.cpp:67
MLIRContext * getContext() const
Definition Builders.h:56
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:209
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:712
Block & back()
Definition Region.h:64
iterator begin()
Definition Region.h:55
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Type conversion from builtin types to SPIR-V types for shader interface.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition SymbolTable.h:76
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static ArrayType get(Type elementType, unsigned elementCount)
An attribute that specifies the information regarding the interface variable: descriptor set,...
static PointerType get(Type pointeeType, StorageClass storageClass)
ResourceLimitsAttr getResourceLimits() const
Returns the target resource limits.
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
TargetEnvAttr getAttr() const
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
StringRef getInterfaceVarABIAttrName()
Returns the attribute name for specifying argument ABI information.
bool needsInterfaceVarABIAttrs(TargetEnvAttr targetAttr)
Returns whether the given SPIR-V target (described by TargetEnvAttr) needs ABI attributes for interfa...
InterfaceVarABIAttr getInterfaceVarABIAttr(unsigned descriptorSet, unsigned binding, std::optional< StorageClass > storageClass, MLIRContext *context)
Gets the InterfaceVarABIAttr given its fields.
Value getBuiltinVariableValue(Operation *op, BuiltIn builtin, Type integerType, OpBuilder &builder, StringRef prefix="__builtin__", StringRef suffix="__")
Returns the value for the given builtin variable.
EntryPointABIAttr lookupEntryPointABI(Operation *op)
Queries the entry point ABI on the nearest function-like op containing the given op.
StringRef getTargetEnvAttrName()
Returns the attribute name for specifying SPIR-V target environment.
DenseI32ArrayAttr lookupLocalWorkGroupSize(Operation *op)
Queries the local workgroup size from entry point ABI on the nearest function-like op containing the ...
AddressingModel getAddressingModel(TargetEnvAttr targetAttr, bool use64bitAddress)
Returns addressing model selected based on target environment.
FailureOr< MemoryModel > getMemoryModel(TargetEnvAttr targetAttr)
Returns memory model selected based on target environment.
StringRef getEntryPointABIAttrName()
Returns the attribute name for specifying entry point information.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
void populateGPUToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating GPU Ops to SPIR-V ops.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369