//===- GPUToSPIRV.cpp - GPU to SPIR-V Patterns ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert GPU dialect to SPIR-V dialect.
//
//===----------------------------------------------------------------------===//
12 
#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h"

#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Matchers.h"

#include <optional>
26 
27 using namespace mlir;
28 
29 static constexpr const char kSPIRVModule[] = "__spv__";
30 
31 namespace {
32 /// Pattern lowering GPU block/thread size/id to loading SPIR-V invocation
33 /// builtin variables.
34 template <typename SourceOp, spirv::BuiltIn builtin>
35 class LaunchConfigConversion : public OpConversionPattern<SourceOp> {
36 public:
38 
39  LogicalResult
40  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
41  ConversionPatternRewriter &rewriter) const override;
42 };
43 
44 /// Pattern lowering subgroup size/id to loading SPIR-V invocation
45 /// builtin variables.
46 template <typename SourceOp, spirv::BuiltIn builtin>
47 class SingleDimLaunchConfigConversion : public OpConversionPattern<SourceOp> {
48 public:
50 
51  LogicalResult
52  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
53  ConversionPatternRewriter &rewriter) const override;
54 };
55 
56 /// This is separate because in Vulkan workgroup size is exposed to shaders via
57 /// a constant with WorkgroupSize decoration. So here we cannot generate a
58 /// builtin variable; instead the information in the `spirv.entry_point_abi`
59 /// attribute on the surrounding FuncOp is used to replace the gpu::BlockDimOp.
60 class WorkGroupSizeConversion : public OpConversionPattern<gpu::BlockDimOp> {
61 public:
62  WorkGroupSizeConversion(const TypeConverter &typeConverter,
63  MLIRContext *context)
64  : OpConversionPattern(typeConverter, context, /*benefit*/ 10) {}
65 
66  LogicalResult
67  matchAndRewrite(gpu::BlockDimOp op, OpAdaptor adaptor,
68  ConversionPatternRewriter &rewriter) const override;
69 };
70 
71 /// Pattern to convert a kernel function in GPU dialect within a spirv.module.
72 class GPUFuncOpConversion final : public OpConversionPattern<gpu::GPUFuncOp> {
73 public:
75 
76  LogicalResult
77  matchAndRewrite(gpu::GPUFuncOp funcOp, OpAdaptor adaptor,
78  ConversionPatternRewriter &rewriter) const override;
79 
80 private:
81  SmallVector<int32_t, 3> workGroupSizeAsInt32;
82 };
83 
84 /// Pattern to convert a gpu.module to a spirv.module.
85 class GPUModuleConversion final : public OpConversionPattern<gpu::GPUModuleOp> {
86 public:
88 
89  LogicalResult
90  matchAndRewrite(gpu::GPUModuleOp moduleOp, OpAdaptor adaptor,
91  ConversionPatternRewriter &rewriter) const override;
92 };
93 
94 /// Pattern to convert a gpu.return into a SPIR-V return.
95 // TODO: This can go to DRR when GPU return has operands.
96 class GPUReturnOpConversion final : public OpConversionPattern<gpu::ReturnOp> {
97 public:
99 
100  LogicalResult
101  matchAndRewrite(gpu::ReturnOp returnOp, OpAdaptor adaptor,
102  ConversionPatternRewriter &rewriter) const override;
103 };
104 
105 /// Pattern to convert a gpu.barrier op into a spirv.ControlBarrier op.
106 class GPUBarrierConversion final : public OpConversionPattern<gpu::BarrierOp> {
107 public:
109 
110  LogicalResult
111  matchAndRewrite(gpu::BarrierOp barrierOp, OpAdaptor adaptor,
112  ConversionPatternRewriter &rewriter) const override;
113 };
114 
115 /// Pattern to convert a gpu.shuffle op into a spirv.GroupNonUniformShuffle op.
116 class GPUShuffleConversion final : public OpConversionPattern<gpu::ShuffleOp> {
117 public:
119 
120  LogicalResult
121  matchAndRewrite(gpu::ShuffleOp shuffleOp, OpAdaptor adaptor,
122  ConversionPatternRewriter &rewriter) const override;
123 };
124 
125 class GPUPrintfConversion final : public OpConversionPattern<gpu::PrintfOp> {
126 public:
128 
129  LogicalResult
130  matchAndRewrite(gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
131  ConversionPatternRewriter &rewriter) const override;
132 };
133 
134 } // namespace
135 
136 //===----------------------------------------------------------------------===//
137 // Builtins.
138 //===----------------------------------------------------------------------===//
139 
140 template <typename SourceOp, spirv::BuiltIn builtin>
141 LogicalResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
142  SourceOp op, typename SourceOp::Adaptor adaptor,
143  ConversionPatternRewriter &rewriter) const {
144  auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
145  Type indexType = typeConverter->getIndexType();
146 
147  // For Vulkan, these SPIR-V builtin variables are required to be a vector of
148  // type <3xi32> by the spec:
149  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/NumWorkgroups.html
150  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/WorkgroupId.html
151  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/WorkgroupSize.html
152  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/LocalInvocationId.html
153  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/LocalInvocationId.html
154  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/GlobalInvocationId.html
155  //
156  // For OpenCL, it depends on the Physical32/Physical64 addressing model:
157  // https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_Env.html#_built_in_variables
158  bool forShader =
159  typeConverter->getTargetEnv().allows(spirv::Capability::Shader);
160  Type builtinType = forShader ? rewriter.getIntegerType(32) : indexType;
161 
162  Value vector =
163  spirv::getBuiltinVariableValue(op, builtin, builtinType, rewriter);
164  Value dim = rewriter.create<spirv::CompositeExtractOp>(
165  op.getLoc(), builtinType, vector,
166  rewriter.getI32ArrayAttr({static_cast<int32_t>(op.getDimension())}));
167  if (forShader && builtinType != indexType)
168  dim = rewriter.create<spirv::UConvertOp>(op.getLoc(), indexType, dim);
169  rewriter.replaceOp(op, dim);
170  return success();
171 }
172 
173 template <typename SourceOp, spirv::BuiltIn builtin>
174 LogicalResult
175 SingleDimLaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
176  SourceOp op, typename SourceOp::Adaptor adaptor,
177  ConversionPatternRewriter &rewriter) const {
178  auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
179  Type indexType = typeConverter->getIndexType();
180  Type i32Type = rewriter.getIntegerType(32);
181 
182  // For Vulkan, these SPIR-V builtin variables are required to be a vector of
183  // type i32 by the spec:
184  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/NumSubgroups.html
185  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/SubgroupId.html
186  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/SubgroupSize.html
187  //
188  // For OpenCL, they are also required to be i32:
189  // https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_Env.html#_built_in_variables
190  Value builtinValue =
191  spirv::getBuiltinVariableValue(op, builtin, i32Type, rewriter);
192  if (i32Type != indexType)
193  builtinValue = rewriter.create<spirv::UConvertOp>(op.getLoc(), indexType,
194  builtinValue);
195  rewriter.replaceOp(op, builtinValue);
196  return success();
197 }
198 
199 LogicalResult WorkGroupSizeConversion::matchAndRewrite(
200  gpu::BlockDimOp op, OpAdaptor adaptor,
201  ConversionPatternRewriter &rewriter) const {
202  DenseI32ArrayAttr workGroupSizeAttr = spirv::lookupLocalWorkGroupSize(op);
203  if (!workGroupSizeAttr)
204  return failure();
205 
206  int val =
207  workGroupSizeAttr.asArrayRef()[static_cast<int32_t>(op.getDimension())];
208  auto convertedType =
209  getTypeConverter()->convertType(op.getResult().getType());
210  if (!convertedType)
211  return failure();
212  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
213  op, convertedType, IntegerAttr::get(convertedType, val));
214  return success();
215 }
216 
217 //===----------------------------------------------------------------------===//
218 // GPUFuncOp
219 //===----------------------------------------------------------------------===//
220 
221 // Legalizes a GPU function as an entry SPIR-V function.
222 static spirv::FuncOp
223 lowerAsEntryFunction(gpu::GPUFuncOp funcOp, const TypeConverter &typeConverter,
224  ConversionPatternRewriter &rewriter,
225  spirv::EntryPointABIAttr entryPointInfo,
227  auto fnType = funcOp.getFunctionType();
228  if (fnType.getNumResults()) {
229  funcOp.emitError("SPIR-V lowering only supports entry functions"
230  "with no return values right now");
231  return nullptr;
232  }
233  if (!argABIInfo.empty() && fnType.getNumInputs() != argABIInfo.size()) {
234  funcOp.emitError(
235  "lowering as entry functions requires ABI info for all arguments "
236  "or none of them");
237  return nullptr;
238  }
239  // Update the signature to valid SPIR-V types and add the ABI
240  // attributes. These will be "materialized" by using the
241  // LowerABIAttributesPass.
242  TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
243  {
244  for (const auto &argType :
245  enumerate(funcOp.getFunctionType().getInputs())) {
246  auto convertedType = typeConverter.convertType(argType.value());
247  if (!convertedType)
248  return nullptr;
249  signatureConverter.addInputs(argType.index(), convertedType);
250  }
251  }
252  auto newFuncOp = rewriter.create<spirv::FuncOp>(
253  funcOp.getLoc(), funcOp.getName(),
254  rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
255  std::nullopt));
256  for (const auto &namedAttr : funcOp->getAttrs()) {
257  if (namedAttr.getName() == funcOp.getFunctionTypeAttrName() ||
258  namedAttr.getName() == SymbolTable::getSymbolAttrName())
259  continue;
260  newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
261  }
262 
263  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
264  newFuncOp.end());
265  if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
266  &signatureConverter)))
267  return nullptr;
268  rewriter.eraseOp(funcOp);
269 
270  // Set the attributes for argument and the function.
271  StringRef argABIAttrName = spirv::getInterfaceVarABIAttrName();
272  for (auto argIndex : llvm::seq<unsigned>(0, argABIInfo.size())) {
273  newFuncOp.setArgAttr(argIndex, argABIAttrName, argABIInfo[argIndex]);
274  }
275  newFuncOp->setAttr(spirv::getEntryPointABIAttrName(), entryPointInfo);
276 
277  return newFuncOp;
278 }
279 
280 /// Populates `argABI` with spirv.interface_var_abi attributes for lowering
281 /// gpu.func to spirv.func if no arguments have the attributes set
282 /// already. Returns failure if any argument has the ABI attribute set already.
283 static LogicalResult
284 getDefaultABIAttrs(const spirv::TargetEnv &targetEnv, gpu::GPUFuncOp funcOp,
286  if (!spirv::needsInterfaceVarABIAttrs(targetEnv))
287  return success();
288 
289  for (auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
290  if (funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
292  return failure();
293  // Vulkan's interface variable requirements needs scalars to be wrapped in a
294  // struct. The struct held in storage buffer.
295  std::optional<spirv::StorageClass> sc;
296  if (funcOp.getArgument(argIndex).getType().isIntOrIndexOrFloat())
297  sc = spirv::StorageClass::StorageBuffer;
298  argABI.push_back(
299  spirv::getInterfaceVarABIAttr(0, argIndex, sc, funcOp.getContext()));
300  }
301  return success();
302 }
303 
304 LogicalResult GPUFuncOpConversion::matchAndRewrite(
305  gpu::GPUFuncOp funcOp, OpAdaptor adaptor,
306  ConversionPatternRewriter &rewriter) const {
307  if (!gpu::GPUDialect::isKernel(funcOp))
308  return failure();
309 
310  auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
312  if (failed(
313  getDefaultABIAttrs(typeConverter->getTargetEnv(), funcOp, argABI))) {
314  argABI.clear();
315  for (auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
316  // If the ABI is already specified, use it.
317  auto abiAttr = funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
319  if (!abiAttr) {
320  funcOp.emitRemark(
321  "match failure: missing 'spirv.interface_var_abi' attribute at "
322  "argument ")
323  << argIndex;
324  return failure();
325  }
326  argABI.push_back(abiAttr);
327  }
328  }
329 
330  auto entryPointAttr = spirv::lookupEntryPointABI(funcOp);
331  if (!entryPointAttr) {
332  funcOp.emitRemark(
333  "match failure: missing 'spirv.entry_point_abi' attribute");
334  return failure();
335  }
336  spirv::FuncOp newFuncOp = lowerAsEntryFunction(
337  funcOp, *getTypeConverter(), rewriter, entryPointAttr, argABI);
338  if (!newFuncOp)
339  return failure();
340  newFuncOp->removeAttr(
341  rewriter.getStringAttr(gpu::GPUDialect::getKernelFuncAttrName()));
342  return success();
343 }
344 
345 //===----------------------------------------------------------------------===//
346 // ModuleOp with gpu.module.
347 //===----------------------------------------------------------------------===//
348 
349 LogicalResult GPUModuleConversion::matchAndRewrite(
350  gpu::GPUModuleOp moduleOp, OpAdaptor adaptor,
351  ConversionPatternRewriter &rewriter) const {
352  auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
353  const spirv::TargetEnv &targetEnv = typeConverter->getTargetEnv();
354  spirv::AddressingModel addressingModel = spirv::getAddressingModel(
355  targetEnv, typeConverter->getOptions().use64bitIndex);
356  FailureOr<spirv::MemoryModel> memoryModel = spirv::getMemoryModel(targetEnv);
357  if (failed(memoryModel))
358  return moduleOp.emitRemark(
359  "cannot deduce memory model from 'spirv.target_env'");
360 
361  // Add a keyword to the module name to avoid symbolic conflict.
362  std::string spvModuleName = (kSPIRVModule + moduleOp.getName()).str();
363  auto spvModule = rewriter.create<spirv::ModuleOp>(
364  moduleOp.getLoc(), addressingModel, *memoryModel, std::nullopt,
365  StringRef(spvModuleName));
366 
367  // Move the region from the module op into the SPIR-V module.
368  Region &spvModuleRegion = spvModule.getRegion();
369  rewriter.inlineRegionBefore(moduleOp.getBodyRegion(), spvModuleRegion,
370  spvModuleRegion.begin());
371  // The spirv.module build method adds a block. Remove that.
372  rewriter.eraseBlock(&spvModuleRegion.back());
373 
374  // Some of the patterns call `lookupTargetEnv` during conversion and they
375  // will fail if called after GPUModuleConversion and we don't preserve
376  // `TargetEnv` attribute.
377  // Copy TargetEnvAttr only if it is attached directly to the GPUModuleOp.
378  if (auto attr = moduleOp->getAttrOfType<spirv::TargetEnvAttr>(
380  spvModule->setAttr(spirv::getTargetEnvAttrName(), attr);
381 
382  rewriter.eraseOp(moduleOp);
383  return success();
384 }
385 
386 //===----------------------------------------------------------------------===//
387 // GPU return inside kernel functions to SPIR-V return.
388 //===----------------------------------------------------------------------===//
389 
390 LogicalResult GPUReturnOpConversion::matchAndRewrite(
391  gpu::ReturnOp returnOp, OpAdaptor adaptor,
392  ConversionPatternRewriter &rewriter) const {
393  if (!adaptor.getOperands().empty())
394  return failure();
395 
396  rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp);
397  return success();
398 }
399 
400 //===----------------------------------------------------------------------===//
401 // Barrier.
402 //===----------------------------------------------------------------------===//
403 
404 LogicalResult GPUBarrierConversion::matchAndRewrite(
405  gpu::BarrierOp barrierOp, OpAdaptor adaptor,
406  ConversionPatternRewriter &rewriter) const {
407  MLIRContext *context = getContext();
408  // Both execution and memory scope should be workgroup.
409  auto scope = spirv::ScopeAttr::get(context, spirv::Scope::Workgroup);
410  // Require acquire and release memory semantics for workgroup memory.
411  auto memorySemantics = spirv::MemorySemanticsAttr::get(
412  context, spirv::MemorySemantics::WorkgroupMemory |
413  spirv::MemorySemantics::AcquireRelease);
414  rewriter.replaceOpWithNewOp<spirv::ControlBarrierOp>(barrierOp, scope, scope,
415  memorySemantics);
416  return success();
417 }
418 
419 //===----------------------------------------------------------------------===//
420 // Shuffle
421 //===----------------------------------------------------------------------===//
422 
423 LogicalResult GPUShuffleConversion::matchAndRewrite(
424  gpu::ShuffleOp shuffleOp, OpAdaptor adaptor,
425  ConversionPatternRewriter &rewriter) const {
426  // Require the shuffle width to be the same as the target's subgroup size,
427  // given that for SPIR-V non-uniform subgroup ops, we cannot select
428  // participating invocations.
429  auto targetEnv = getTypeConverter<SPIRVTypeConverter>()->getTargetEnv();
430  unsigned subgroupSize =
431  targetEnv.getAttr().getResourceLimits().getSubgroupSize();
432  IntegerAttr widthAttr;
433  if (!matchPattern(shuffleOp.getWidth(), m_Constant(&widthAttr)) ||
434  widthAttr.getValue().getZExtValue() != subgroupSize)
435  return rewriter.notifyMatchFailure(
436  shuffleOp, "shuffle width and target subgroup size mismatch");
437 
438  assert(!adaptor.getOffset().getType().isSignedInteger() &&
439  "shuffle offset must be a signless/unsigned integer");
440 
441  Location loc = shuffleOp.getLoc();
442  auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
443  Value result;
444  Value validVal;
445 
446  switch (shuffleOp.getMode()) {
447  case gpu::ShuffleMode::XOR: {
448  result = rewriter.create<spirv::GroupNonUniformShuffleXorOp>(
449  loc, scope, adaptor.getValue(), adaptor.getOffset());
450  validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(),
451  shuffleOp.getLoc(), rewriter);
452  break;
453  }
454  case gpu::ShuffleMode::IDX: {
455  result = rewriter.create<spirv::GroupNonUniformShuffleOp>(
456  loc, scope, adaptor.getValue(), adaptor.getOffset());
457  validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(),
458  shuffleOp.getLoc(), rewriter);
459  break;
460  }
461  case gpu::ShuffleMode::DOWN: {
462  result = rewriter.create<spirv::GroupNonUniformShuffleDownOp>(
463  loc, scope, adaptor.getValue(), adaptor.getOffset());
464 
465  Value laneId = rewriter.create<gpu::LaneIdOp>(loc, widthAttr);
466  Value resultLaneId =
467  rewriter.create<arith::AddIOp>(loc, laneId, adaptor.getOffset());
468  validVal = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
469  resultLaneId, adaptor.getWidth());
470  break;
471  }
472  case gpu::ShuffleMode::UP: {
473  result = rewriter.create<spirv::GroupNonUniformShuffleUpOp>(
474  loc, scope, adaptor.getValue(), adaptor.getOffset());
475 
476  Value laneId = rewriter.create<gpu::LaneIdOp>(loc, widthAttr);
477  Value resultLaneId =
478  rewriter.create<arith::SubIOp>(loc, laneId, adaptor.getOffset());
479  auto i32Type = rewriter.getIntegerType(32);
480  validVal = rewriter.create<arith::CmpIOp>(
481  loc, arith::CmpIPredicate::sge, resultLaneId,
482  rewriter.create<arith::ConstantOp>(
483  loc, i32Type, rewriter.getIntegerAttr(i32Type, 0)));
484  break;
485  }
486  }
487 
488  rewriter.replaceOp(shuffleOp, {result, validVal});
489  return success();
490 }
491 
492 //===----------------------------------------------------------------------===//
493 // Group ops
494 //===----------------------------------------------------------------------===//
495 
496 template <typename UniformOp, typename NonUniformOp>
498  Value arg, bool isGroup, bool isUniform,
499  std::optional<uint32_t> clusterSize) {
500  Type type = arg.getType();
501  auto scope = mlir::spirv::ScopeAttr::get(builder.getContext(),
502  isGroup ? spirv::Scope::Workgroup
503  : spirv::Scope::Subgroup);
504  auto groupOp = spirv::GroupOperationAttr::get(
505  builder.getContext(), clusterSize.has_value()
506  ? spirv::GroupOperation::ClusteredReduce
507  : spirv::GroupOperation::Reduce);
508  if (isUniform) {
509  return builder.create<UniformOp>(loc, type, scope, groupOp, arg)
510  .getResult();
511  }
512 
513  Value clusterSizeValue;
514  if (clusterSize.has_value())
515  clusterSizeValue = builder.create<spirv::ConstantOp>(
516  loc, builder.getI32Type(),
517  builder.getIntegerAttr(builder.getI32Type(), *clusterSize));
518 
519  return builder
520  .create<NonUniformOp>(loc, type, scope, groupOp, arg, clusterSizeValue)
521  .getResult();
522 }
523 
524 static std::optional<Value>
526  gpu::AllReduceOperation opType, bool isGroup,
527  bool isUniform, std::optional<uint32_t> clusterSize) {
528  enum class ElemType { Float, Boolean, Integer };
529  using FuncT = Value (*)(OpBuilder &, Location, Value, bool, bool,
530  std::optional<uint32_t>);
531  struct OpHandler {
532  gpu::AllReduceOperation kind;
533  ElemType elemType;
534  FuncT func;
535  };
536 
537  Type type = arg.getType();
538  ElemType elementType;
539  if (isa<FloatType>(type)) {
540  elementType = ElemType::Float;
541  } else if (auto intTy = dyn_cast<IntegerType>(type)) {
542  elementType = (intTy.getIntOrFloatBitWidth() == 1) ? ElemType::Boolean
543  : ElemType::Integer;
544  } else {
545  return std::nullopt;
546  }
547 
548  // TODO(https://github.com/llvm/llvm-project/issues/73459): The SPIR-V spec
549  // does not specify how -0.0 / +0.0 and NaN values are handled in *FMin/*FMax
550  // reduction ops. We should account possible precision requirements in this
551  // conversion.
552 
553  using ReduceType = gpu::AllReduceOperation;
554  const OpHandler handlers[] = {
555  {ReduceType::ADD, ElemType::Integer,
556  &createGroupReduceOpImpl<spirv::GroupIAddOp,
557  spirv::GroupNonUniformIAddOp>},
558  {ReduceType::ADD, ElemType::Float,
559  &createGroupReduceOpImpl<spirv::GroupFAddOp,
560  spirv::GroupNonUniformFAddOp>},
561  {ReduceType::MUL, ElemType::Integer,
562  &createGroupReduceOpImpl<spirv::GroupIMulKHROp,
563  spirv::GroupNonUniformIMulOp>},
564  {ReduceType::MUL, ElemType::Float,
565  &createGroupReduceOpImpl<spirv::GroupFMulKHROp,
566  spirv::GroupNonUniformFMulOp>},
567  {ReduceType::MINUI, ElemType::Integer,
568  &createGroupReduceOpImpl<spirv::GroupUMinOp,
569  spirv::GroupNonUniformUMinOp>},
570  {ReduceType::MINSI, ElemType::Integer,
571  &createGroupReduceOpImpl<spirv::GroupSMinOp,
572  spirv::GroupNonUniformSMinOp>},
573  {ReduceType::MINNUMF, ElemType::Float,
574  &createGroupReduceOpImpl<spirv::GroupFMinOp,
575  spirv::GroupNonUniformFMinOp>},
576  {ReduceType::MAXUI, ElemType::Integer,
577  &createGroupReduceOpImpl<spirv::GroupUMaxOp,
578  spirv::GroupNonUniformUMaxOp>},
579  {ReduceType::MAXSI, ElemType::Integer,
580  &createGroupReduceOpImpl<spirv::GroupSMaxOp,
581  spirv::GroupNonUniformSMaxOp>},
582  {ReduceType::MAXNUMF, ElemType::Float,
583  &createGroupReduceOpImpl<spirv::GroupFMaxOp,
584  spirv::GroupNonUniformFMaxOp>},
585  {ReduceType::MINIMUMF, ElemType::Float,
586  &createGroupReduceOpImpl<spirv::GroupFMinOp,
587  spirv::GroupNonUniformFMinOp>},
588  {ReduceType::MAXIMUMF, ElemType::Float,
589  &createGroupReduceOpImpl<spirv::GroupFMaxOp,
590  spirv::GroupNonUniformFMaxOp>}};
591 
592  for (const OpHandler &handler : handlers)
593  if (handler.kind == opType && elementType == handler.elemType)
594  return handler.func(builder, loc, arg, isGroup, isUniform, clusterSize);
595 
596  return std::nullopt;
597 }
598 
599 /// Pattern to convert a gpu.all_reduce op into a SPIR-V group op.
601  : public OpConversionPattern<gpu::AllReduceOp> {
602 public:
604 
605  LogicalResult
606  matchAndRewrite(gpu::AllReduceOp op, OpAdaptor adaptor,
607  ConversionPatternRewriter &rewriter) const override {
608  auto opType = op.getOp();
609 
610  // gpu.all_reduce can have either reduction op attribute or reduction
611  // region. Only attribute version is supported.
612  if (!opType)
613  return failure();
614 
615  auto result =
616  createGroupReduceOp(rewriter, op.getLoc(), adaptor.getValue(), *opType,
617  /*isGroup*/ true, op.getUniform(), std::nullopt);
618  if (!result)
619  return failure();
620 
621  rewriter.replaceOp(op, *result);
622  return success();
623  }
624 };
625 
626 /// Pattern to convert a gpu.subgroup_reduce op into a SPIR-V group op.
628  : public OpConversionPattern<gpu::SubgroupReduceOp> {
629 public:
631 
632  LogicalResult
633  matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
634  ConversionPatternRewriter &rewriter) const override {
635  if (op.getClusterStride() > 1) {
636  return rewriter.notifyMatchFailure(
637  op, "lowering for cluster stride > 1 is not implemented");
638  }
639 
640  if (!isa<spirv::ScalarType>(adaptor.getValue().getType()))
641  return rewriter.notifyMatchFailure(op, "reduction type is not a scalar");
642 
643  auto result = createGroupReduceOp(
644  rewriter, op.getLoc(), adaptor.getValue(), adaptor.getOp(),
645  /*isGroup=*/false, adaptor.getUniform(), op.getClusterSize());
646  if (!result)
647  return failure();
648 
649  rewriter.replaceOp(op, *result);
650  return success();
651  }
652 };
653 
654 // Formulate a unique variable/constant name after
655 // searching in the module for existing variable/constant names.
656 // This is to avoid name collision with existing variables.
657 // Example: printfMsg0, printfMsg1, printfMsg2, ...
658 static std::string makeVarName(spirv::ModuleOp moduleOp, llvm::Twine prefix) {
659  std::string name;
660  unsigned number = 0;
661 
662  do {
663  name.clear();
664  name = (prefix + llvm::Twine(number++)).str();
665  } while (moduleOp.lookupSymbol(name));
666 
667  return name;
668 }
669 
670 /// Pattern to convert a gpu.printf op into a SPIR-V CLPrintf op.
671 
672 LogicalResult GPUPrintfConversion::matchAndRewrite(
673  gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
674  ConversionPatternRewriter &rewriter) const {
675 
676  Location loc = gpuPrintfOp.getLoc();
677 
678  auto moduleOp = gpuPrintfOp->getParentOfType<spirv::ModuleOp>();
679  if (!moduleOp)
680  return failure();
681 
682  // SPIR-V global variable is used to initialize printf
683  // format string value, if there are multiple printf messages,
684  // each global var needs to be created with a unique name.
685  std::string globalVarName = makeVarName(moduleOp, llvm::Twine("printfMsg"));
686  spirv::GlobalVariableOp globalVar;
687 
688  IntegerType i8Type = rewriter.getI8Type();
689  IntegerType i32Type = rewriter.getI32Type();
690 
691  // Each character of printf format string is
692  // stored as a spec constant. We need to create
693  // unique name for this spec constant like
694  // @printfMsg0_sc0, @printfMsg0_sc1, ... by searching in the module
695  // for existing spec constant names.
696  auto createSpecConstant = [&](unsigned value) {
697  auto attr = rewriter.getI8IntegerAttr(value);
698  std::string specCstName =
699  makeVarName(moduleOp, llvm::Twine(globalVarName) + "_sc");
700 
701  return rewriter.create<spirv::SpecConstantOp>(
702  loc, rewriter.getStringAttr(specCstName), attr);
703  };
704  {
705  Operation *parent =
706  SymbolTable::getNearestSymbolTable(gpuPrintfOp->getParentOp());
707 
708  ConversionPatternRewriter::InsertionGuard guard(rewriter);
709 
710  Block &entryBlock = *parent->getRegion(0).begin();
711  rewriter.setInsertionPointToStart(
712  &entryBlock); // insertion point at module level
713 
714  // Create Constituents with SpecConstant by scanning format string
715  // Each character of format string is stored as a spec constant
716  // and then these spec constants are used to create a
717  // SpecConstantCompositeOp.
718  llvm::SmallString<20> formatString(adaptor.getFormat());
719  formatString.push_back('\0'); // Null terminate for C.
720  SmallVector<Attribute, 4> constituents;
721  for (char c : formatString) {
722  spirv::SpecConstantOp cSpecConstantOp = createSpecConstant(c);
723  constituents.push_back(SymbolRefAttr::get(cSpecConstantOp));
724  }
725 
726  // Create SpecConstantCompositeOp to initialize the global variable
727  size_t contentSize = constituents.size();
728  auto globalType = spirv::ArrayType::get(i8Type, contentSize);
729  spirv::SpecConstantCompositeOp specCstComposite;
730  // There will be one SpecConstantCompositeOp per printf message/global var,
731  // so no need do lookup for existing ones.
732  std::string specCstCompositeName =
733  (llvm::Twine(globalVarName) + "_scc").str();
734 
735  specCstComposite = rewriter.create<spirv::SpecConstantCompositeOp>(
736  loc, TypeAttr::get(globalType),
737  rewriter.getStringAttr(specCstCompositeName),
738  rewriter.getArrayAttr(constituents));
739 
740  auto ptrType = spirv::PointerType::get(
741  globalType, spirv::StorageClass::UniformConstant);
742 
743  // Define a GlobalVarOp initialized using specialized constants
744  // that is used to specify the printf format string
745  // to be passed to the SPIRV CLPrintfOp.
746  globalVar = rewriter.create<spirv::GlobalVariableOp>(
747  loc, ptrType, globalVarName, FlatSymbolRefAttr::get(specCstComposite));
748 
749  globalVar->setAttr("Constant", rewriter.getUnitAttr());
750  }
751  // Get SSA value of Global variable and create pointer to i8 to point to
752  // the format string.
753  Value globalPtr = rewriter.create<spirv::AddressOfOp>(loc, globalVar);
754  Value fmtStr = rewriter.create<spirv::BitcastOp>(
755  loc,
756  spirv::PointerType::get(i8Type, spirv::StorageClass::UniformConstant),
757  globalPtr);
758 
759  // Get printf arguments.
760  auto printfArgs = llvm::to_vector_of<Value, 4>(adaptor.getArgs());
761 
762  rewriter.create<spirv::CLPrintfOp>(loc, i32Type, fmtStr, printfArgs);
763 
764  // Need to erase the gpu.printf op as gpu.printf does not use result vs
765  // spirv::CLPrintfOp has i32 resultType so cannot replace with new SPIR-V
766  // printf op.
767  rewriter.eraseOp(gpuPrintfOp);
768 
769  return success();
770 }
771 
772 //===----------------------------------------------------------------------===//
773 // GPU To SPIRV Patterns.
774 //===----------------------------------------------------------------------===//
775 
778  patterns.add<
779  GPUBarrierConversion, GPUFuncOpConversion, GPUModuleConversion,
780  GPUReturnOpConversion, GPUShuffleConversion,
781  LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
782  LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
783  LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,
784  LaunchConfigConversion<gpu::ThreadIdOp,
785  spirv::BuiltIn::LocalInvocationId>,
786  LaunchConfigConversion<gpu::GlobalIdOp,
787  spirv::BuiltIn::GlobalInvocationId>,
788  SingleDimLaunchConfigConversion<gpu::SubgroupIdOp,
789  spirv::BuiltIn::SubgroupId>,
790  SingleDimLaunchConfigConversion<gpu::NumSubgroupsOp,
791  spirv::BuiltIn::NumSubgroups>,
792  SingleDimLaunchConfigConversion<gpu::SubgroupSizeOp,
793  spirv::BuiltIn::SubgroupSize>,
794  SingleDimLaunchConfigConversion<
795  gpu::LaneIdOp, spirv::BuiltIn::SubgroupLocalInvocationId>,
796  WorkGroupSizeConversion, GPUAllReduceConversion,
797  GPUSubgroupReduceConversion, GPUPrintfConversion>(typeConverter,
798  patterns.getContext());
799 }
static LogicalResult getDefaultABIAttrs(const spirv::TargetEnv &targetEnv, gpu::GPUFuncOp funcOp, SmallVectorImpl< spirv::InterfaceVarABIAttr > &argABI)
Populates argABI with spirv.interface_var_abi attributes for lowering gpu.func to spirv....
Definition: GPUToSPIRV.cpp:284
static constexpr const char kSPIRVModule[]
Definition: GPUToSPIRV.cpp:29
static std::optional< Value > createGroupReduceOp(OpBuilder &builder, Location loc, Value arg, gpu::AllReduceOperation opType, bool isGroup, bool isUniform, std::optional< uint32_t > clusterSize)
Definition: GPUToSPIRV.cpp:525
static Value createGroupReduceOpImpl(OpBuilder &builder, Location loc, Value arg, bool isGroup, bool isUniform, std::optional< uint32_t > clusterSize)
Definition: GPUToSPIRV.cpp:497
static spirv::FuncOp lowerAsEntryFunction(gpu::GPUFuncOp funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter, spirv::EntryPointABIAttr entryPointInfo, ArrayRef< spirv::InterfaceVarABIAttr > argABIInfo)
Definition: GPUToSPIRV.cpp:223
static std::string makeVarName(spirv::ModuleOp moduleOp, llvm::Twine prefix)
Definition: GPUToSPIRV.cpp:658
static MLIRContext * getContext(OpFoldResult val)
union mlir::linalg::@1204::ArityGroupAndKind::Kind kind
#define MINUI(lhs, rhs)
constexpr unsigned subgroupSize
HW dependent constants.
Pattern to convert a gpu.all_reduce op into a SPIR-V group op.
Definition: GPUToSPIRV.cpp:601
LogicalResult matchAndRewrite(gpu::AllReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
Methods that operate on the SourceOp type.
Definition: GPUToSPIRV.cpp:606
Pattern to convert a gpu.subgroup_reduce op into a SPIR-V group op.
Definition: GPUToSPIRV.cpp:628
LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
Methods that operate on the SourceOp type.
Definition: GPUToSPIRV.cpp:633
Block represents an ordered list of Operations.
Definition: Block.h:33
UnitAttr getUnitAttr()
Definition: Builders.cpp:96
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:226
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:274
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:78
IntegerType getI32Type()
Definition: Builders.cpp:65
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:69
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:260
MLIRContext * getContext() const
Definition: Builders.h:55
IntegerType getI1Type()
Definition: Builders.cpp:55
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:264
IntegerType getI8Type()
Definition: Builders.cpp:61
IntegerAttr getI8IntegerAttr(int8_t value)
Definition: Builders.cpp:219
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition: Builders.h:96
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void eraseBlock(Block *block) override
PatternRewriter hook for erase all operations in a block.
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:205
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:455
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:686
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:582
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator begin()
Definition: Region.h:55
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:682
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
Type conversion from builtin types to SPIR-V types for shader interface.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
Type conversion class.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
static ArrayType get(Type elementType, unsigned elementCount)
Definition: SPIRVTypes.cpp:52
An attribute that specifies the information regarding the interface variable: descriptor set,...
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:427
An attribute that specifies the target version, allowed extensions and capabilities,...
ResourceLimitsAttr getResourceLimits() const
Returns the target resource limits.
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
Definition: TargetAndABI.h:29
TargetEnvAttr getAttr() const
Definition: TargetAndABI.h:62
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
StringRef getInterfaceVarABIAttrName()
Returns the attribute name for specifying argument ABI information.
bool needsInterfaceVarABIAttrs(TargetEnvAttr targetAttr)
Returns whether the given SPIR-V target (described by TargetEnvAttr) needs ABI attributes for interfa...
InterfaceVarABIAttr getInterfaceVarABIAttr(unsigned descriptorSet, unsigned binding, std::optional< StorageClass > storageClass, MLIRContext *context)
Gets the InterfaceVarABIAttr given its fields.
Value getBuiltinVariableValue(Operation *op, BuiltIn builtin, Type integerType, OpBuilder &builder, StringRef prefix="__builtin__", StringRef suffix="__")
Returns the value for the given builtin variable.
EntryPointABIAttr lookupEntryPointABI(Operation *op)
Queries the entry point ABI on the nearest function-like op containing the given op.
StringRef getTargetEnvAttrName()
Returns the attribute name for specifying SPIR-V target environment.
DenseI32ArrayAttr lookupLocalWorkGroupSize(Operation *op)
Queries the local workgroup size from entry point ABI on the nearest function-like op containing the ...
AddressingModel getAddressingModel(TargetEnvAttr targetAttr, bool use64bitAddress)
Returns addressing model selected based on target environment.
FailureOr< MemoryModel > getMemoryModel(TargetEnvAttr targetAttr)
Returns memory model selected based on target environment.
StringRef getEntryPointABIAttrName()
Returns the attribute name for specifying entry point information.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
const FrozenRewritePatternSet & patterns
void populateGPUToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating GPU Ops to SPIR-V ops.
Definition: GPUToSPIRV.cpp:776
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369