MLIR  22.0.0git
GPUToSPIRV.cpp
Go to the documentation of this file.
1 //===- GPUToSPIRV.cpp - GPU to SPIR-V Patterns ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert GPU dialect to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
21 #include "mlir/IR/Matchers.h"
23 #include <optional>
24 
25 using namespace mlir;
26 
/// Prefix prepended to a gpu.module's name to form the name of the generated
/// spirv.module, avoiding symbol collisions with the original module.
static constexpr const char kSPIRVModule[] = "__spv__";
28 
29 namespace {
30 /// Pattern lowering GPU block/thread size/id to loading SPIR-V invocation
31 /// builtin variables.
32 template <typename SourceOp, spirv::BuiltIn builtin>
33 class LaunchConfigConversion : public OpConversionPattern<SourceOp> {
34 public:
36 
37  LogicalResult
38  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
39  ConversionPatternRewriter &rewriter) const override;
40 };
41 
42 /// Pattern lowering subgroup size/id to loading SPIR-V invocation
43 /// builtin variables.
44 template <typename SourceOp, spirv::BuiltIn builtin>
45 class SingleDimLaunchConfigConversion : public OpConversionPattern<SourceOp> {
46 public:
48 
49  LogicalResult
50  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
51  ConversionPatternRewriter &rewriter) const override;
52 };
53 
54 /// This is separate because in Vulkan workgroup size is exposed to shaders via
55 /// a constant with WorkgroupSize decoration. So here we cannot generate a
56 /// builtin variable; instead the information in the `spirv.entry_point_abi`
57 /// attribute on the surrounding FuncOp is used to replace the gpu::BlockDimOp.
58 class WorkGroupSizeConversion : public OpConversionPattern<gpu::BlockDimOp> {
59 public:
60  WorkGroupSizeConversion(const TypeConverter &typeConverter,
61  MLIRContext *context)
62  : OpConversionPattern(typeConverter, context, /*benefit*/ 10) {}
63 
64  LogicalResult
65  matchAndRewrite(gpu::BlockDimOp op, OpAdaptor adaptor,
66  ConversionPatternRewriter &rewriter) const override;
67 };
68 
69 /// Pattern to convert a kernel function in GPU dialect within a spirv.module.
70 class GPUFuncOpConversion final : public OpConversionPattern<gpu::GPUFuncOp> {
71 public:
73 
74  LogicalResult
75  matchAndRewrite(gpu::GPUFuncOp funcOp, OpAdaptor adaptor,
76  ConversionPatternRewriter &rewriter) const override;
77 
78 private:
79  SmallVector<int32_t, 3> workGroupSizeAsInt32;
80 };
81 
82 /// Pattern to convert a gpu.module to a spirv.module.
83 class GPUModuleConversion final : public OpConversionPattern<gpu::GPUModuleOp> {
84 public:
86 
87  LogicalResult
88  matchAndRewrite(gpu::GPUModuleOp moduleOp, OpAdaptor adaptor,
89  ConversionPatternRewriter &rewriter) const override;
90 };
91 
92 /// Pattern to convert a gpu.return into a SPIR-V return.
93 // TODO: This can go to DRR when GPU return has operands.
94 class GPUReturnOpConversion final : public OpConversionPattern<gpu::ReturnOp> {
95 public:
97 
98  LogicalResult
99  matchAndRewrite(gpu::ReturnOp returnOp, OpAdaptor adaptor,
100  ConversionPatternRewriter &rewriter) const override;
101 };
102 
103 /// Pattern to convert a gpu.barrier op into a spirv.ControlBarrier op.
104 class GPUBarrierConversion final : public OpConversionPattern<gpu::BarrierOp> {
105 public:
107 
108  LogicalResult
109  matchAndRewrite(gpu::BarrierOp barrierOp, OpAdaptor adaptor,
110  ConversionPatternRewriter &rewriter) const override;
111 };
112 
113 /// Pattern to convert a gpu.shuffle op into a spirv.GroupNonUniformShuffle op.
114 class GPUShuffleConversion final : public OpConversionPattern<gpu::ShuffleOp> {
115 public:
117 
118  LogicalResult
119  matchAndRewrite(gpu::ShuffleOp shuffleOp, OpAdaptor adaptor,
120  ConversionPatternRewriter &rewriter) const override;
121 };
122 
123 /// Pattern to convert a gpu.rotate op into a spirv.GroupNonUniformRotateKHROp.
124 class GPURotateConversion final : public OpConversionPattern<gpu::RotateOp> {
125 public:
127 
128  LogicalResult
129  matchAndRewrite(gpu::RotateOp rotateOp, OpAdaptor adaptor,
130  ConversionPatternRewriter &rewriter) const override;
131 };
132 
133 class GPUPrintfConversion final : public OpConversionPattern<gpu::PrintfOp> {
134 public:
136 
137  LogicalResult
138  matchAndRewrite(gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
139  ConversionPatternRewriter &rewriter) const override;
140 };
141 
142 } // namespace
143 
144 //===----------------------------------------------------------------------===//
145 // Builtins.
146 //===----------------------------------------------------------------------===//
147 
148 template <typename SourceOp, spirv::BuiltIn builtin>
149 LogicalResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
150  SourceOp op, typename SourceOp::Adaptor adaptor,
151  ConversionPatternRewriter &rewriter) const {
152  auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
153  Type indexType = typeConverter->getIndexType();
154 
155  // For Vulkan, these SPIR-V builtin variables are required to be a vector of
156  // type <3xi32> by the spec:
157  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/NumWorkgroups.html
158  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/WorkgroupId.html
159  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/WorkgroupSize.html
160  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/LocalInvocationId.html
161  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/LocalInvocationId.html
162  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/GlobalInvocationId.html
163  //
164  // For OpenCL, it depends on the Physical32/Physical64 addressing model:
165  // https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_Env.html#_built_in_variables
166  bool forShader =
167  typeConverter->getTargetEnv().allows(spirv::Capability::Shader);
168  Type builtinType = forShader ? rewriter.getIntegerType(32) : indexType;
169 
170  Value vector =
171  spirv::getBuiltinVariableValue(op, builtin, builtinType, rewriter);
172  Value dim = spirv::CompositeExtractOp::create(
173  rewriter, op.getLoc(), builtinType, vector,
174  rewriter.getI32ArrayAttr({static_cast<int32_t>(op.getDimension())}));
175  if (forShader && builtinType != indexType)
176  dim = spirv::UConvertOp::create(rewriter, op.getLoc(), indexType, dim);
177  rewriter.replaceOp(op, dim);
178  return success();
179 }
180 
181 template <typename SourceOp, spirv::BuiltIn builtin>
182 LogicalResult
183 SingleDimLaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
184  SourceOp op, typename SourceOp::Adaptor adaptor,
185  ConversionPatternRewriter &rewriter) const {
186  auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
187  Type indexType = typeConverter->getIndexType();
188  Type i32Type = rewriter.getIntegerType(32);
189 
190  // For Vulkan, these SPIR-V builtin variables are required to be a vector of
191  // type i32 by the spec:
192  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/NumSubgroups.html
193  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/SubgroupId.html
194  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/SubgroupSize.html
195  //
196  // For OpenCL, they are also required to be i32:
197  // https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_Env.html#_built_in_variables
198  Value builtinValue =
199  spirv::getBuiltinVariableValue(op, builtin, i32Type, rewriter);
200  if (i32Type != indexType)
201  builtinValue = spirv::UConvertOp::create(rewriter, op.getLoc(), indexType,
202  builtinValue);
203  rewriter.replaceOp(op, builtinValue);
204  return success();
205 }
206 
207 LogicalResult WorkGroupSizeConversion::matchAndRewrite(
208  gpu::BlockDimOp op, OpAdaptor adaptor,
209  ConversionPatternRewriter &rewriter) const {
210  DenseI32ArrayAttr workGroupSizeAttr = spirv::lookupLocalWorkGroupSize(op);
211  if (!workGroupSizeAttr)
212  return failure();
213 
214  int val =
215  workGroupSizeAttr.asArrayRef()[static_cast<int32_t>(op.getDimension())];
216  auto convertedType =
217  getTypeConverter()->convertType(op.getResult().getType());
218  if (!convertedType)
219  return failure();
220  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
221  op, convertedType, IntegerAttr::get(convertedType, val));
222  return success();
223 }
224 
225 //===----------------------------------------------------------------------===//
226 // GPUFuncOp
227 //===----------------------------------------------------------------------===//
228 
229 // Legalizes a GPU function as an entry SPIR-V function.
230 static spirv::FuncOp
231 lowerAsEntryFunction(gpu::GPUFuncOp funcOp, const TypeConverter &typeConverter,
232  ConversionPatternRewriter &rewriter,
233  spirv::EntryPointABIAttr entryPointInfo,
235  auto fnType = funcOp.getFunctionType();
236  if (fnType.getNumResults()) {
237  funcOp.emitError("SPIR-V lowering only supports entry functions"
238  "with no return values right now");
239  return nullptr;
240  }
241  if (!argABIInfo.empty() && fnType.getNumInputs() != argABIInfo.size()) {
242  funcOp.emitError(
243  "lowering as entry functions requires ABI info for all arguments "
244  "or none of them");
245  return nullptr;
246  }
247  // Update the signature to valid SPIR-V types and add the ABI
248  // attributes. These will be "materialized" by using the
249  // LowerABIAttributesPass.
250  TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
251  {
252  for (const auto &argType :
253  enumerate(funcOp.getFunctionType().getInputs())) {
254  auto convertedType = typeConverter.convertType(argType.value());
255  if (!convertedType)
256  return nullptr;
257  signatureConverter.addInputs(argType.index(), convertedType);
258  }
259  }
260  auto newFuncOp = spirv::FuncOp::create(
261  rewriter, funcOp.getLoc(), funcOp.getName(),
262  rewriter.getFunctionType(signatureConverter.getConvertedTypes(), {}));
263  for (const auto &namedAttr : funcOp->getAttrs()) {
264  if (namedAttr.getName() == funcOp.getFunctionTypeAttrName() ||
265  namedAttr.getName() == SymbolTable::getSymbolAttrName())
266  continue;
267  newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
268  }
269 
270  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
271  newFuncOp.end());
272  if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
273  &signatureConverter)))
274  return nullptr;
275  rewriter.eraseOp(funcOp);
276 
277  // Set the attributes for argument and the function.
278  StringRef argABIAttrName = spirv::getInterfaceVarABIAttrName();
279  for (auto argIndex : llvm::seq<unsigned>(0, argABIInfo.size())) {
280  newFuncOp.setArgAttr(argIndex, argABIAttrName, argABIInfo[argIndex]);
281  }
282  newFuncOp->setAttr(spirv::getEntryPointABIAttrName(), entryPointInfo);
283 
284  return newFuncOp;
285 }
286 
287 /// Populates `argABI` with spirv.interface_var_abi attributes for lowering
288 /// gpu.func to spirv.func if no arguments have the attributes set
289 /// already. Returns failure if any argument has the ABI attribute set already.
290 static LogicalResult
291 getDefaultABIAttrs(const spirv::TargetEnv &targetEnv, gpu::GPUFuncOp funcOp,
293  if (!spirv::needsInterfaceVarABIAttrs(targetEnv))
294  return success();
295 
296  for (auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
297  if (funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
299  return failure();
300  // Vulkan's interface variable requirements needs scalars to be wrapped in a
301  // struct. The struct held in storage buffer.
302  std::optional<spirv::StorageClass> sc;
303  if (funcOp.getArgument(argIndex).getType().isIntOrIndexOrFloat())
304  sc = spirv::StorageClass::StorageBuffer;
305  argABI.push_back(
306  spirv::getInterfaceVarABIAttr(0, argIndex, sc, funcOp.getContext()));
307  }
308  return success();
309 }
310 
311 LogicalResult GPUFuncOpConversion::matchAndRewrite(
312  gpu::GPUFuncOp funcOp, OpAdaptor adaptor,
313  ConversionPatternRewriter &rewriter) const {
314  if (!gpu::GPUDialect::isKernel(funcOp))
315  return failure();
316 
317  auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
319  if (failed(
320  getDefaultABIAttrs(typeConverter->getTargetEnv(), funcOp, argABI))) {
321  argABI.clear();
322  for (auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
323  // If the ABI is already specified, use it.
324  auto abiAttr = funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
326  if (!abiAttr) {
327  funcOp.emitRemark(
328  "match failure: missing 'spirv.interface_var_abi' attribute at "
329  "argument ")
330  << argIndex;
331  return failure();
332  }
333  argABI.push_back(abiAttr);
334  }
335  }
336 
337  auto entryPointAttr = spirv::lookupEntryPointABI(funcOp);
338  if (!entryPointAttr) {
339  funcOp.emitRemark(
340  "match failure: missing 'spirv.entry_point_abi' attribute");
341  return failure();
342  }
343  spirv::FuncOp newFuncOp = lowerAsEntryFunction(
344  funcOp, *getTypeConverter(), rewriter, entryPointAttr, argABI);
345  if (!newFuncOp)
346  return failure();
347  newFuncOp->removeAttr(
348  rewriter.getStringAttr(gpu::GPUDialect::getKernelFuncAttrName()));
349  return success();
350 }
351 
352 //===----------------------------------------------------------------------===//
353 // ModuleOp with gpu.module.
354 //===----------------------------------------------------------------------===//
355 
356 LogicalResult GPUModuleConversion::matchAndRewrite(
357  gpu::GPUModuleOp moduleOp, OpAdaptor adaptor,
358  ConversionPatternRewriter &rewriter) const {
359  auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
360  const spirv::TargetEnv &targetEnv = typeConverter->getTargetEnv();
361  spirv::AddressingModel addressingModel = spirv::getAddressingModel(
362  targetEnv, typeConverter->getOptions().use64bitIndex);
363  FailureOr<spirv::MemoryModel> memoryModel = spirv::getMemoryModel(targetEnv);
364  if (failed(memoryModel))
365  return moduleOp.emitRemark(
366  "cannot deduce memory model from 'spirv.target_env'");
367 
368  // Add a keyword to the module name to avoid symbolic conflict.
369  std::string spvModuleName = (kSPIRVModule + moduleOp.getName()).str();
370  auto spvModule = spirv::ModuleOp::create(
371  rewriter, moduleOp.getLoc(), addressingModel, *memoryModel, std::nullopt,
372  StringRef(spvModuleName));
373 
374  // Move the region from the module op into the SPIR-V module.
375  Region &spvModuleRegion = spvModule.getRegion();
376  rewriter.inlineRegionBefore(moduleOp.getBodyRegion(), spvModuleRegion,
377  spvModuleRegion.begin());
378  // The spirv.module build method adds a block. Remove that.
379  rewriter.eraseBlock(&spvModuleRegion.back());
380 
381  // Some of the patterns call `lookupTargetEnv` during conversion and they
382  // will fail if called after GPUModuleConversion and we don't preserve
383  // `TargetEnv` attribute.
384  // Copy TargetEnvAttr only if it is attached directly to the GPUModuleOp.
385  if (auto attr = moduleOp->getAttrOfType<spirv::TargetEnvAttr>(
387  spvModule->setAttr(spirv::getTargetEnvAttrName(), attr);
388  if (ArrayAttr targets = moduleOp.getTargetsAttr()) {
389  for (Attribute targetAttr : targets)
390  if (auto spirvTargetEnvAttr =
391  dyn_cast<spirv::TargetEnvAttr>(targetAttr)) {
392  spvModule->setAttr(spirv::getTargetEnvAttrName(), spirvTargetEnvAttr);
393  break;
394  }
395  }
396 
397  rewriter.eraseOp(moduleOp);
398  return success();
399 }
400 
401 //===----------------------------------------------------------------------===//
402 // GPU return inside kernel functions to SPIR-V return.
403 //===----------------------------------------------------------------------===//
404 
405 LogicalResult GPUReturnOpConversion::matchAndRewrite(
406  gpu::ReturnOp returnOp, OpAdaptor adaptor,
407  ConversionPatternRewriter &rewriter) const {
408  if (!adaptor.getOperands().empty())
409  return failure();
410 
411  rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp);
412  return success();
413 }
414 
415 //===----------------------------------------------------------------------===//
416 // Barrier.
417 //===----------------------------------------------------------------------===//
418 
419 LogicalResult GPUBarrierConversion::matchAndRewrite(
420  gpu::BarrierOp barrierOp, OpAdaptor adaptor,
421  ConversionPatternRewriter &rewriter) const {
422  MLIRContext *context = getContext();
423  // Both execution and memory scope should be workgroup.
424  auto scope = spirv::ScopeAttr::get(context, spirv::Scope::Workgroup);
425  // Require acquire and release memory semantics for workgroup memory.
426  auto memorySemantics = spirv::MemorySemanticsAttr::get(
427  context, spirv::MemorySemantics::WorkgroupMemory |
428  spirv::MemorySemantics::AcquireRelease);
429  rewriter.replaceOpWithNewOp<spirv::ControlBarrierOp>(barrierOp, scope, scope,
430  memorySemantics);
431  return success();
432 }
433 
434 //===----------------------------------------------------------------------===//
435 // Shuffle
436 //===----------------------------------------------------------------------===//
437 
438 LogicalResult GPUShuffleConversion::matchAndRewrite(
439  gpu::ShuffleOp shuffleOp, OpAdaptor adaptor,
440  ConversionPatternRewriter &rewriter) const {
441  // Require the shuffle width to be the same as the target's subgroup size,
442  // given that for SPIR-V non-uniform subgroup ops, we cannot select
443  // participating invocations.
444  auto targetEnv = getTypeConverter<SPIRVTypeConverter>()->getTargetEnv();
445  unsigned subgroupSize =
446  targetEnv.getAttr().getResourceLimits().getSubgroupSize();
447  IntegerAttr widthAttr;
448  if (!matchPattern(shuffleOp.getWidth(), m_Constant(&widthAttr)) ||
449  widthAttr.getValue().getZExtValue() != subgroupSize)
450  return rewriter.notifyMatchFailure(
451  shuffleOp, "shuffle width and target subgroup size mismatch");
452 
453  assert(!adaptor.getOffset().getType().isSignedInteger() &&
454  "shuffle offset must be a signless/unsigned integer");
455 
456  Location loc = shuffleOp.getLoc();
457  auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
458  Value result;
459  Value validVal;
460 
461  switch (shuffleOp.getMode()) {
462  case gpu::ShuffleMode::XOR: {
463  result = spirv::GroupNonUniformShuffleXorOp::create(
464  rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset());
465  validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(),
466  shuffleOp.getLoc(), rewriter);
467  break;
468  }
469  case gpu::ShuffleMode::IDX: {
470  result = spirv::GroupNonUniformShuffleOp::create(
471  rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset());
472  validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(),
473  shuffleOp.getLoc(), rewriter);
474  break;
475  }
476  case gpu::ShuffleMode::DOWN: {
477  result = spirv::GroupNonUniformShuffleDownOp::create(
478  rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset());
479 
480  Value laneId = gpu::LaneIdOp::create(rewriter, loc, widthAttr);
481  Value resultLaneId =
482  arith::AddIOp::create(rewriter, loc, laneId, adaptor.getOffset());
483  validVal = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ult,
484  resultLaneId, adaptor.getWidth());
485  break;
486  }
487  case gpu::ShuffleMode::UP: {
488  result = spirv::GroupNonUniformShuffleUpOp::create(
489  rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset());
490 
491  Value laneId = gpu::LaneIdOp::create(rewriter, loc, widthAttr);
492  Value resultLaneId =
493  arith::SubIOp::create(rewriter, loc, laneId, adaptor.getOffset());
494  auto i32Type = rewriter.getIntegerType(32);
495  validVal = arith::CmpIOp::create(
496  rewriter, loc, arith::CmpIPredicate::sge, resultLaneId,
497  arith::ConstantOp::create(rewriter, loc, i32Type,
498  rewriter.getIntegerAttr(i32Type, 0)));
499  break;
500  }
501  }
502 
503  rewriter.replaceOp(shuffleOp, {result, validVal});
504  return success();
505 }
506 
507 //===----------------------------------------------------------------------===//
508 // Rotate
509 //===----------------------------------------------------------------------===//
510 
511 LogicalResult GPURotateConversion::matchAndRewrite(
512  gpu::RotateOp rotateOp, OpAdaptor adaptor,
513  ConversionPatternRewriter &rewriter) const {
514  const spirv::TargetEnv &targetEnv =
515  getTypeConverter<SPIRVTypeConverter>()->getTargetEnv();
516  unsigned subgroupSize =
517  targetEnv.getAttr().getResourceLimits().getSubgroupSize();
518  unsigned width = rotateOp.getWidth();
519  if (width > subgroupSize)
520  return rewriter.notifyMatchFailure(
521  rotateOp, "rotate width is larger than target subgroup size");
522 
523  Location loc = rotateOp.getLoc();
524  auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
525  Value offsetVal =
526  arith::ConstantOp::create(rewriter, loc, adaptor.getOffsetAttr());
527  Value widthVal =
528  arith::ConstantOp::create(rewriter, loc, adaptor.getWidthAttr());
529  Value rotateResult = spirv::GroupNonUniformRotateKHROp::create(
530  rewriter, loc, scope, adaptor.getValue(), offsetVal, widthVal);
531  Value validVal;
532  if (width == subgroupSize) {
533  validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(), loc, rewriter);
534  } else {
535  IntegerAttr widthAttr = adaptor.getWidthAttr();
536  Value laneId = gpu::LaneIdOp::create(rewriter, loc, widthAttr);
537  validVal = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ult,
538  laneId, widthVal);
539  }
540 
541  rewriter.replaceOp(rotateOp, {rotateResult, validVal});
542  return success();
543 }
544 
545 //===----------------------------------------------------------------------===//
546 // Group ops
547 //===----------------------------------------------------------------------===//
548 
549 template <typename UniformOp, typename NonUniformOp>
551  Value arg, bool isGroup, bool isUniform,
552  std::optional<uint32_t> clusterSize) {
553  Type type = arg.getType();
554  auto scope = mlir::spirv::ScopeAttr::get(builder.getContext(),
555  isGroup ? spirv::Scope::Workgroup
556  : spirv::Scope::Subgroup);
557  auto groupOp = spirv::GroupOperationAttr::get(
558  builder.getContext(), clusterSize.has_value()
559  ? spirv::GroupOperation::ClusteredReduce
560  : spirv::GroupOperation::Reduce);
561  if (isUniform) {
562  return UniformOp::create(builder, loc, type, scope, groupOp, arg)
563  .getResult();
564  }
565 
566  Value clusterSizeValue;
567  if (clusterSize.has_value())
568  clusterSizeValue = spirv::ConstantOp::create(
569  builder, loc, builder.getI32Type(),
570  builder.getIntegerAttr(builder.getI32Type(), *clusterSize));
571 
572  return NonUniformOp::create(builder, loc, type, scope, groupOp, arg,
573  clusterSizeValue)
574  .getResult();
575 }
576 
577 static std::optional<Value>
579  gpu::AllReduceOperation opType, bool isGroup,
580  bool isUniform, std::optional<uint32_t> clusterSize) {
581  enum class ElemType { Float, Boolean, Integer };
582  using FuncT = Value (*)(OpBuilder &, Location, Value, bool, bool,
583  std::optional<uint32_t>);
584  struct OpHandler {
585  gpu::AllReduceOperation kind;
586  ElemType elemType;
587  FuncT func;
588  };
589 
590  Type type = arg.getType();
591  ElemType elementType;
592  if (isa<FloatType>(type)) {
593  elementType = ElemType::Float;
594  } else if (auto intTy = dyn_cast<IntegerType>(type)) {
595  elementType = (intTy.getIntOrFloatBitWidth() == 1) ? ElemType::Boolean
596  : ElemType::Integer;
597  } else {
598  return std::nullopt;
599  }
600 
601  // TODO(https://github.com/llvm/llvm-project/issues/73459): The SPIR-V spec
602  // does not specify how -0.0 / +0.0 and NaN values are handled in *FMin/*FMax
603  // reduction ops. We should account possible precision requirements in this
604  // conversion.
605 
606  using ReduceType = gpu::AllReduceOperation;
607  const OpHandler handlers[] = {
608  {ReduceType::ADD, ElemType::Integer,
609  &createGroupReduceOpImpl<spirv::GroupIAddOp,
610  spirv::GroupNonUniformIAddOp>},
611  {ReduceType::ADD, ElemType::Float,
612  &createGroupReduceOpImpl<spirv::GroupFAddOp,
613  spirv::GroupNonUniformFAddOp>},
614  {ReduceType::MUL, ElemType::Integer,
615  &createGroupReduceOpImpl<spirv::GroupIMulKHROp,
616  spirv::GroupNonUniformIMulOp>},
617  {ReduceType::MUL, ElemType::Float,
618  &createGroupReduceOpImpl<spirv::GroupFMulKHROp,
619  spirv::GroupNonUniformFMulOp>},
620  {ReduceType::MINUI, ElemType::Integer,
621  &createGroupReduceOpImpl<spirv::GroupUMinOp,
622  spirv::GroupNonUniformUMinOp>},
623  {ReduceType::MINSI, ElemType::Integer,
624  &createGroupReduceOpImpl<spirv::GroupSMinOp,
625  spirv::GroupNonUniformSMinOp>},
626  {ReduceType::MINNUMF, ElemType::Float,
627  &createGroupReduceOpImpl<spirv::GroupFMinOp,
628  spirv::GroupNonUniformFMinOp>},
629  {ReduceType::MAXUI, ElemType::Integer,
630  &createGroupReduceOpImpl<spirv::GroupUMaxOp,
631  spirv::GroupNonUniformUMaxOp>},
632  {ReduceType::MAXSI, ElemType::Integer,
633  &createGroupReduceOpImpl<spirv::GroupSMaxOp,
634  spirv::GroupNonUniformSMaxOp>},
635  {ReduceType::MAXNUMF, ElemType::Float,
636  &createGroupReduceOpImpl<spirv::GroupFMaxOp,
637  spirv::GroupNonUniformFMaxOp>},
638  {ReduceType::MINIMUMF, ElemType::Float,
639  &createGroupReduceOpImpl<spirv::GroupFMinOp,
640  spirv::GroupNonUniformFMinOp>},
641  {ReduceType::MAXIMUMF, ElemType::Float,
642  &createGroupReduceOpImpl<spirv::GroupFMaxOp,
643  spirv::GroupNonUniformFMaxOp>}};
644 
645  for (const OpHandler &handler : handlers)
646  if (handler.kind == opType && elementType == handler.elemType)
647  return handler.func(builder, loc, arg, isGroup, isUniform, clusterSize);
648 
649  return std::nullopt;
650 }
651 
652 /// Pattern to convert a gpu.all_reduce op into a SPIR-V group op.
654  : public OpConversionPattern<gpu::AllReduceOp> {
655 public:
657 
658  LogicalResult
659  matchAndRewrite(gpu::AllReduceOp op, OpAdaptor adaptor,
660  ConversionPatternRewriter &rewriter) const override {
661  auto opType = op.getOp();
662 
663  // gpu.all_reduce can have either reduction op attribute or reduction
664  // region. Only attribute version is supported.
665  if (!opType)
666  return failure();
667 
668  auto result =
669  createGroupReduceOp(rewriter, op.getLoc(), adaptor.getValue(), *opType,
670  /*isGroup*/ true, op.getUniform(), std::nullopt);
671  if (!result)
672  return failure();
673 
674  rewriter.replaceOp(op, *result);
675  return success();
676  }
677 };
678 
679 /// Pattern to convert a gpu.subgroup_reduce op into a SPIR-V group op.
681  : public OpConversionPattern<gpu::SubgroupReduceOp> {
682 public:
684 
685  LogicalResult
686  matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
687  ConversionPatternRewriter &rewriter) const override {
688  if (op.getClusterStride() > 1) {
689  return rewriter.notifyMatchFailure(
690  op, "lowering for cluster stride > 1 is not implemented");
691  }
692 
693  if (!isa<spirv::ScalarType>(adaptor.getValue().getType()))
694  return rewriter.notifyMatchFailure(op, "reduction type is not a scalar");
695 
696  auto result = createGroupReduceOp(
697  rewriter, op.getLoc(), adaptor.getValue(), adaptor.getOp(),
698  /*isGroup=*/false, adaptor.getUniform(), op.getClusterSize());
699  if (!result)
700  return failure();
701 
702  rewriter.replaceOp(op, *result);
703  return success();
704  }
705 };
706 
707 // Formulate a unique variable/constant name after
708 // searching in the module for existing variable/constant names.
709 // This is to avoid name collision with existing variables.
710 // Example: printfMsg0, printfMsg1, printfMsg2, ...
711 static std::string makeVarName(spirv::ModuleOp moduleOp, llvm::Twine prefix) {
712  std::string name;
713  unsigned number = 0;
714 
715  do {
716  name.clear();
717  name = (prefix + llvm::Twine(number++)).str();
718  } while (moduleOp.lookupSymbol(name));
719 
720  return name;
721 }
722 
723 /// Pattern to convert a gpu.printf op into a SPIR-V CLPrintf op.
724 
725 LogicalResult GPUPrintfConversion::matchAndRewrite(
726  gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
727  ConversionPatternRewriter &rewriter) const {
728 
729  Location loc = gpuPrintfOp.getLoc();
730 
731  auto moduleOp = gpuPrintfOp->getParentOfType<spirv::ModuleOp>();
732  if (!moduleOp)
733  return failure();
734 
735  // SPIR-V global variable is used to initialize printf
736  // format string value, if there are multiple printf messages,
737  // each global var needs to be created with a unique name.
738  std::string globalVarName = makeVarName(moduleOp, llvm::Twine("printfMsg"));
739  spirv::GlobalVariableOp globalVar;
740 
741  IntegerType i8Type = rewriter.getI8Type();
742  IntegerType i32Type = rewriter.getI32Type();
743 
744  // Each character of printf format string is
745  // stored as a spec constant. We need to create
746  // unique name for this spec constant like
747  // @printfMsg0_sc0, @printfMsg0_sc1, ... by searching in the module
748  // for existing spec constant names.
749  auto createSpecConstant = [&](unsigned value) {
750  auto attr = rewriter.getI8IntegerAttr(value);
751  std::string specCstName =
752  makeVarName(moduleOp, llvm::Twine(globalVarName) + "_sc");
753 
754  return spirv::SpecConstantOp::create(
755  rewriter, loc, rewriter.getStringAttr(specCstName), attr);
756  };
757  {
758  Operation *parent =
759  SymbolTable::getNearestSymbolTable(gpuPrintfOp->getParentOp());
760 
761  ConversionPatternRewriter::InsertionGuard guard(rewriter);
762 
763  Block &entryBlock = *parent->getRegion(0).begin();
764  rewriter.setInsertionPointToStart(
765  &entryBlock); // insertion point at module level
766 
767  // Create Constituents with SpecConstant by scanning format string
768  // Each character of format string is stored as a spec constant
769  // and then these spec constants are used to create a
770  // SpecConstantCompositeOp.
771  llvm::SmallString<20> formatString(adaptor.getFormat());
772  formatString.push_back('\0'); // Null terminate for C.
773  SmallVector<Attribute, 4> constituents;
774  for (char c : formatString) {
775  spirv::SpecConstantOp cSpecConstantOp = createSpecConstant(c);
776  constituents.push_back(SymbolRefAttr::get(cSpecConstantOp));
777  }
778 
779  // Create SpecConstantCompositeOp to initialize the global variable
780  size_t contentSize = constituents.size();
781  auto globalType = spirv::ArrayType::get(i8Type, contentSize);
782  spirv::SpecConstantCompositeOp specCstComposite;
783  // There will be one SpecConstantCompositeOp per printf message/global var,
784  // so no need do lookup for existing ones.
785  std::string specCstCompositeName =
786  (llvm::Twine(globalVarName) + "_scc").str();
787 
788  specCstComposite = spirv::SpecConstantCompositeOp::create(
789  rewriter, loc, TypeAttr::get(globalType),
790  rewriter.getStringAttr(specCstCompositeName),
791  rewriter.getArrayAttr(constituents));
792 
793  auto ptrType = spirv::PointerType::get(
794  globalType, spirv::StorageClass::UniformConstant);
795 
796  // Define a GlobalVarOp initialized using specialized constants
797  // that is used to specify the printf format string
798  // to be passed to the SPIRV CLPrintfOp.
799  globalVar = spirv::GlobalVariableOp::create(
800  rewriter, loc, ptrType, globalVarName,
801  FlatSymbolRefAttr::get(specCstComposite));
802 
803  globalVar->setAttr("Constant", rewriter.getUnitAttr());
804  }
805  // Get SSA value of Global variable and create pointer to i8 to point to
806  // the format string.
807  Value globalPtr = spirv::AddressOfOp::create(rewriter, loc, globalVar);
808  Value fmtStr = spirv::BitcastOp::create(
809  rewriter, loc,
810  spirv::PointerType::get(i8Type, spirv::StorageClass::UniformConstant),
811  globalPtr);
812 
813  // Get printf arguments.
814  auto printfArgs = llvm::to_vector_of<Value, 4>(adaptor.getArgs());
815 
816  spirv::CLPrintfOp::create(rewriter, loc, i32Type, fmtStr, printfArgs);
817 
818  // Need to erase the gpu.printf op as gpu.printf does not use result vs
819  // spirv::CLPrintfOp has i32 resultType so cannot replace with new SPIR-V
820  // printf op.
821  rewriter.eraseOp(gpuPrintfOp);
822 
823  return success();
824 }
825 
826 //===----------------------------------------------------------------------===//
827 // GPU To SPIRV Patterns.
828 //===----------------------------------------------------------------------===//
829 
832  patterns.add<
833  GPUBarrierConversion, GPUFuncOpConversion, GPUModuleConversion,
834  GPUReturnOpConversion, GPUShuffleConversion, GPURotateConversion,
835  LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
836  LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
837  LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,
838  LaunchConfigConversion<gpu::ThreadIdOp,
839  spirv::BuiltIn::LocalInvocationId>,
840  LaunchConfigConversion<gpu::GlobalIdOp,
841  spirv::BuiltIn::GlobalInvocationId>,
842  SingleDimLaunchConfigConversion<gpu::SubgroupIdOp,
843  spirv::BuiltIn::SubgroupId>,
844  SingleDimLaunchConfigConversion<gpu::NumSubgroupsOp,
845  spirv::BuiltIn::NumSubgroups>,
846  SingleDimLaunchConfigConversion<gpu::SubgroupSizeOp,
847  spirv::BuiltIn::SubgroupSize>,
848  SingleDimLaunchConfigConversion<
849  gpu::LaneIdOp, spirv::BuiltIn::SubgroupLocalInvocationId>,
850  WorkGroupSizeConversion, GPUAllReduceConversion,
851  GPUSubgroupReduceConversion, GPUPrintfConversion>(typeConverter,
852  patterns.getContext());
853 }
static LogicalResult getDefaultABIAttrs(const spirv::TargetEnv &targetEnv, gpu::GPUFuncOp funcOp, SmallVectorImpl< spirv::InterfaceVarABIAttr > &argABI)
Populates argABI with spirv.interface_var_abi attributes for lowering gpu.func to spirv....
Definition: GPUToSPIRV.cpp:291
static constexpr const char kSPIRVModule[]
Definition: GPUToSPIRV.cpp:27
static std::optional< Value > createGroupReduceOp(OpBuilder &builder, Location loc, Value arg, gpu::AllReduceOperation opType, bool isGroup, bool isUniform, std::optional< uint32_t > clusterSize)
Definition: GPUToSPIRV.cpp:578
static Value createGroupReduceOpImpl(OpBuilder &builder, Location loc, Value arg, bool isGroup, bool isUniform, std::optional< uint32_t > clusterSize)
Definition: GPUToSPIRV.cpp:550
static spirv::FuncOp lowerAsEntryFunction(gpu::GPUFuncOp funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter, spirv::EntryPointABIAttr entryPointInfo, ArrayRef< spirv::InterfaceVarABIAttr > argABIInfo)
Definition: GPUToSPIRV.cpp:231
static std::string makeVarName(spirv::ModuleOp moduleOp, llvm::Twine prefix)
Definition: GPUToSPIRV.cpp:711
static MLIRContext * getContext(OpFoldResult val)
union mlir::linalg::@1242::ArityGroupAndKind::Kind kind
#define MINUI(lhs, rhs)
Pattern to convert a gpu.all_reduce op into a SPIR-V group op.
Definition: GPUToSPIRV.cpp:654
LogicalResult matchAndRewrite(gpu::AllReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
Methods that operate on the SourceOp type.
Definition: GPUToSPIRV.cpp:659
Pattern to convert a gpu.subgroup_reduce op into a SPIR-V group op.
Definition: GPUToSPIRV.cpp:681
LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
Methods that operate on the SourceOp type.
Definition: GPUToSPIRV.cpp:686
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
UnitAttr getUnitAttr()
Definition: Builders.cpp:93
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:223
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:271
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:75
IntegerType getI32Type()
Definition: Builders.cpp:62
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:66
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:257
MLIRContext * getContext() const
Definition: Builders.h:55
IntegerType getI1Type()
Definition: Builders.cpp:52
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:261
IntegerType getI8Type()
Definition: Builders.cpp:58
IntegerAttr getI8IntegerAttr(int8_t value)
Definition: Builders.cpp:216
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition: Builders.h:96
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void eraseBlock(Block *block) override
PatternRewriter hook for erase all operations in a block.
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
This class helps build Operations.
Definition: Builders.h:205
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:686
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Block & back()
Definition: Region.h:64
iterator begin()
Definition: Region.h:55
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:716
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
Type conversion from builtin types to SPIR-V types for shader interface.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
Type conversion class.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
static ArrayType get(Type elementType, unsigned elementCount)
Definition: SPIRVTypes.cpp:50
An attribute that specifies the information regarding the interface variable: descriptor set,...
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:447
An attribute that specifies the target version, allowed extensions and capabilities,...
ResourceLimitsAttr getResourceLimits() const
Returns the target resource limits.
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
Definition: TargetAndABI.h:29
TargetEnvAttr getAttr() const
Definition: TargetAndABI.h:62
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
StringRef getInterfaceVarABIAttrName()
Returns the attribute name for specifying argument ABI information.
bool needsInterfaceVarABIAttrs(TargetEnvAttr targetAttr)
Returns whether the given SPIR-V target (described by TargetEnvAttr) needs ABI attributes for interfa...
InterfaceVarABIAttr getInterfaceVarABIAttr(unsigned descriptorSet, unsigned binding, std::optional< StorageClass > storageClass, MLIRContext *context)
Gets the InterfaceVarABIAttr given its fields.
Value getBuiltinVariableValue(Operation *op, BuiltIn builtin, Type integerType, OpBuilder &builder, StringRef prefix="__builtin__", StringRef suffix="__")
Returns the value for the given builtin variable.
EntryPointABIAttr lookupEntryPointABI(Operation *op)
Queries the entry point ABI on the nearest function-like op containing the given op.
StringRef getTargetEnvAttrName()
Returns the attribute name for specifying SPIR-V target environment.
DenseI32ArrayAttr lookupLocalWorkGroupSize(Operation *op)
Queries the local workgroup size from entry point ABI on the nearest function-like op containing the ...
AddressingModel getAddressingModel(TargetEnvAttr targetAttr, bool use64bitAddress)
Returns addressing model selected based on target environment.
FailureOr< MemoryModel > getMemoryModel(TargetEnvAttr targetAttr)
Returns memory model selected based on target environment.
StringRef getEntryPointABIAttrName()
Returns the attribute name for specifying entry point information.
constexpr unsigned subgroupSize
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
const FrozenRewritePatternSet & patterns
void populateGPUToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating GPU Ops to SPIR-V ops.
Definition: GPUToSPIRV.cpp:830
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369