//===- GPUToSPIRV.cpp - GPU to SPIR-V Patterns ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert GPU dialect to SPIR-V dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallString.h"
#include <optional>

using namespace mlir;

static constexpr const char kSPIRVModule[] = "__spv__";

namespace {
/// Pattern lowering GPU block/thread size/id to loading SPIR-V invocation
/// builtin variables.
template <typename SourceOp, spirv::BuiltIn builtin>
class LaunchConfigConversion : public OpConversionPattern<SourceOp> {
public:
  using OpConversionPattern<SourceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Pattern lowering subgroup size/id to loading SPIR-V invocation
/// builtin variables.
template <typename SourceOp, spirv::BuiltIn builtin>
class SingleDimLaunchConfigConversion : public OpConversionPattern<SourceOp> {
public:
  using OpConversionPattern<SourceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// This is separate because in Vulkan the workgroup size is exposed to shaders
/// via a constant with WorkgroupSize decoration. So here we cannot generate a
/// builtin variable; instead the information in the `spirv.entry_point_abi`
/// attribute on the surrounding FuncOp is used to replace the gpu::BlockDimOp.
class WorkGroupSizeConversion : public OpConversionPattern<gpu::BlockDimOp> {
public:
  WorkGroupSizeConversion(const TypeConverter &typeConverter,
                          MLIRContext *context)
      : OpConversionPattern(typeConverter, context, /*benefit=*/10) {}

  LogicalResult
  matchAndRewrite(gpu::BlockDimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Pattern to convert a kernel function in GPU dialect within a spirv.module.
class GPUFuncOpConversion final : public OpConversionPattern<gpu::GPUFuncOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::GPUFuncOp funcOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;

private:
  SmallVector<int32_t, 3> workGroupSizeAsInt32;
};

/// Pattern to convert a gpu.module to a spirv.module.
class GPUModuleConversion final : public OpConversionPattern<gpu::GPUModuleOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::GPUModuleOp moduleOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Pattern to convert a gpu.return into a SPIR-V return.
// TODO: This can go to DRR when GPU return has operands.
class GPUReturnOpConversion final : public OpConversionPattern<gpu::ReturnOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::ReturnOp returnOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Pattern to convert a gpu.barrier op into a spirv.ControlBarrier op.
class GPUBarrierConversion final : public OpConversionPattern<gpu::BarrierOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::BarrierOp barrierOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Pattern to convert a gpu.shuffle op into a spirv.GroupNonUniformShuffle op.
class GPUShuffleConversion final : public OpConversionPattern<gpu::ShuffleOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::ShuffleOp shuffleOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

class GPUPrintfConversion final : public OpConversionPattern<gpu::PrintfOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

} // namespace

//===----------------------------------------------------------------------===//
// Builtins.
//===----------------------------------------------------------------------===//

template <typename SourceOp, spirv::BuiltIn builtin>
LogicalResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
    SourceOp op, typename SourceOp::Adaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
  Type indexType = typeConverter->getIndexType();

  // For Vulkan, these SPIR-V builtin variables are required to be of type
  // vector<3xi32> by the spec:
  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/NumWorkgroups.html
  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/WorkgroupId.html
  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/WorkgroupSize.html
  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/LocalInvocationId.html
  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/GlobalInvocationId.html
  //
  // For OpenCL, it depends on the Physical32/Physical64 addressing model:
  // https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_Env.html#_built_in_variables
  bool forShader =
      typeConverter->getTargetEnv().allows(spirv::Capability::Shader);
  Type builtinType = forShader ? rewriter.getIntegerType(32) : indexType;

  Value vector =
      spirv::getBuiltinVariableValue(op, builtin, builtinType, rewriter);
  Value dim = rewriter.create<spirv::CompositeExtractOp>(
      op.getLoc(), builtinType, vector,
      rewriter.getI32ArrayAttr({static_cast<int32_t>(op.getDimension())}));
  if (forShader && builtinType != indexType)
    dim = rewriter.create<spirv::UConvertOp>(op.getLoc(), indexType, dim);
  rewriter.replaceOp(op, dim);
  return success();
}
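
// Illustrative example (not part of the upstream file): for a shader target
// with a 64-bit index type, `%tid = gpu.thread_id x` is rewritten roughly as
//
//   %ids = ... load of the LocalInvocationId builtin variable ... : vector<3xi32>
//   %x   = spirv.CompositeExtract %ids[0 : i32] : vector<3xi32>
//   %tid = spirv.UConvert %x : i32 to i64
//
// where the builtin-variable load is materialized by getBuiltinVariableValue.
// The SPIR-V assembly shown is approximate and only clarifies the pattern's
// effect.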

template <typename SourceOp, spirv::BuiltIn builtin>
LogicalResult
SingleDimLaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
    SourceOp op, typename SourceOp::Adaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
  Type indexType = typeConverter->getIndexType();
  Type i32Type = rewriter.getIntegerType(32);

  // For Vulkan, these SPIR-V builtin variables are required to be of type i32
  // by the spec:
  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/NumSubgroups.html
  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/SubgroupId.html
  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/SubgroupSize.html
  //
  // For OpenCL, they are also required to be i32:
  // https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_Env.html#_built_in_variables
  Value builtinValue =
      spirv::getBuiltinVariableValue(op, builtin, i32Type, rewriter);
  if (i32Type != indexType)
    builtinValue = rewriter.create<spirv::UConvertOp>(op.getLoc(), indexType,
                                                      builtinValue);
  rewriter.replaceOp(op, builtinValue);
  return success();
}

LogicalResult WorkGroupSizeConversion::matchAndRewrite(
    gpu::BlockDimOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  DenseI32ArrayAttr workGroupSizeAttr = spirv::lookupLocalWorkGroupSize(op);
  if (!workGroupSizeAttr)
    return failure();

  int val =
      workGroupSizeAttr.asArrayRef()[static_cast<int32_t>(op.getDimension())];
  auto convertedType =
      getTypeConverter()->convertType(op.getResult().getType());
  if (!convertedType)
    return failure();
  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
      op, convertedType, IntegerAttr::get(convertedType, val));
  return success();
}
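
// Illustrative example (not part of the upstream file): if the surrounding
// function carries an entry point ABI attribute roughly of the form
//   spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>
// then `gpu.block_dim y` folds to a `spirv.Constant 4` of the converted index
// type. The attribute syntax shown is approximate.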

//===----------------------------------------------------------------------===//
// GPUFuncOp
//===----------------------------------------------------------------------===//

// Legalizes a GPU function as an entry SPIR-V function.
static spirv::FuncOp
lowerAsEntryFunction(gpu::GPUFuncOp funcOp, const TypeConverter &typeConverter,
                     ConversionPatternRewriter &rewriter,
                     spirv::EntryPointABIAttr entryPointInfo,
                     ArrayRef<spirv::InterfaceVarABIAttr> argABIInfo) {
  auto fnType = funcOp.getFunctionType();
  if (fnType.getNumResults()) {
    funcOp.emitError("SPIR-V lowering only supports entry functions "
                     "with no return values right now");
    return nullptr;
  }
  if (!argABIInfo.empty() && fnType.getNumInputs() != argABIInfo.size()) {
    funcOp.emitError(
        "lowering as entry functions requires ABI info for all arguments "
        "or none of them");
    return nullptr;
  }
  // Update the signature to valid SPIR-V types and add the ABI
  // attributes. These will be "materialized" by using the
  // LowerABIAttributesPass.
  TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
  {
    for (const auto &argType :
         enumerate(funcOp.getFunctionType().getInputs())) {
      auto convertedType = typeConverter.convertType(argType.value());
      if (!convertedType)
        return nullptr;
      signatureConverter.addInputs(argType.index(), convertedType);
    }
  }
  auto newFuncOp = rewriter.create<spirv::FuncOp>(
      funcOp.getLoc(), funcOp.getName(),
      rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
                               std::nullopt));
  for (const auto &namedAttr : funcOp->getAttrs()) {
    if (namedAttr.getName() == funcOp.getFunctionTypeAttrName() ||
        namedAttr.getName() == SymbolTable::getSymbolAttrName())
      continue;
    newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
  }

  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
                              newFuncOp.end());
  if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
                                         &signatureConverter)))
    return nullptr;
  rewriter.eraseOp(funcOp);

  // Set the attributes for the arguments and the function.
  StringRef argABIAttrName = spirv::getInterfaceVarABIAttrName();
  for (auto argIndex : llvm::seq<unsigned>(0, argABIInfo.size())) {
    newFuncOp.setArgAttr(argIndex, argABIAttrName, argABIInfo[argIndex]);
  }
  newFuncOp->setAttr(spirv::getEntryPointABIAttrName(), entryPointInfo);

  return newFuncOp;
}

/// Populates `argABI` with spirv.interface_var_abi attributes for lowering
/// gpu.func to spirv.func if no arguments have the attributes set
/// already. Returns failure if any argument has the ABI attribute set already.
static LogicalResult
getDefaultABIAttrs(const spirv::TargetEnv &targetEnv, gpu::GPUFuncOp funcOp,
                   SmallVectorImpl<spirv::InterfaceVarABIAttr> &argABI) {
  if (!spirv::needsInterfaceVarABIAttrs(targetEnv))
    return success();

  for (auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
    if (funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
            argIndex, spirv::getInterfaceVarABIAttrName()))
      return failure();
    // Vulkan's interface variable requirements need scalars to be wrapped in
    // a struct. The struct is held in a storage buffer.
    std::optional<spirv::StorageClass> sc;
    if (funcOp.getArgument(argIndex).getType().isIntOrIndexOrFloat())
      sc = spirv::StorageClass::StorageBuffer;
    argABI.push_back(
        spirv::getInterfaceVarABIAttr(0, argIndex, sc, funcOp.getContext()));
  }
  return success();
}
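
// Illustrative example (not part of the upstream file): with the defaults
// above, the i-th kernel argument gets an attribute roughly like
//   spirv.interface_var_abi = #spirv.interface_var_abi<(0, i), StorageBuffer>
// for scalar arguments (descriptor set 0, binding i), and the same attribute
// without the storage class for non-scalar arguments. The attribute syntax is
// approximate.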

LogicalResult GPUFuncOpConversion::matchAndRewrite(
    gpu::GPUFuncOp funcOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (!gpu::GPUDialect::isKernel(funcOp))
    return failure();

  auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
  SmallVector<spirv::InterfaceVarABIAttr, 4> argABI;
  if (failed(
          getDefaultABIAttrs(typeConverter->getTargetEnv(), funcOp, argABI))) {
    argABI.clear();
    for (auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
      // If the ABI is already specified, use it.
      auto abiAttr = funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
          argIndex, spirv::getInterfaceVarABIAttrName());
      if (!abiAttr) {
        funcOp.emitRemark(
            "match failure: missing 'spirv.interface_var_abi' attribute at "
            "argument ")
            << argIndex;
        return failure();
      }
      argABI.push_back(abiAttr);
    }
  }

  auto entryPointAttr = spirv::lookupEntryPointABI(funcOp);
  if (!entryPointAttr) {
    funcOp.emitRemark(
        "match failure: missing 'spirv.entry_point_abi' attribute");
    return failure();
  }
  spirv::FuncOp newFuncOp = lowerAsEntryFunction(
      funcOp, *getTypeConverter(), rewriter, entryPointAttr, argABI);
  if (!newFuncOp)
    return failure();
  newFuncOp->removeAttr(
      rewriter.getStringAttr(gpu::GPUDialect::getKernelFuncAttrName()));
  return success();
}

//===----------------------------------------------------------------------===//
// ModuleOp with gpu.module.
//===----------------------------------------------------------------------===//

LogicalResult GPUModuleConversion::matchAndRewrite(
    gpu::GPUModuleOp moduleOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
  const spirv::TargetEnv &targetEnv = typeConverter->getTargetEnv();
  spirv::AddressingModel addressingModel = spirv::getAddressingModel(
      targetEnv, typeConverter->getOptions().use64bitIndex);
  FailureOr<spirv::MemoryModel> memoryModel = spirv::getMemoryModel(targetEnv);
  if (failed(memoryModel))
    return moduleOp.emitRemark(
        "cannot deduce memory model from 'spirv.target_env'");

  // Add a prefix to the module name to avoid symbol name conflicts.
  std::string spvModuleName = (kSPIRVModule + moduleOp.getName()).str();
  auto spvModule = rewriter.create<spirv::ModuleOp>(
      moduleOp.getLoc(), addressingModel, *memoryModel, std::nullopt,
      StringRef(spvModuleName));

  // Move the region from the module op into the SPIR-V module.
  Region &spvModuleRegion = spvModule.getRegion();
  rewriter.inlineRegionBefore(moduleOp.getBodyRegion(), spvModuleRegion,
                              spvModuleRegion.begin());
  // The spirv.module build method adds a block. Remove that.
  rewriter.eraseBlock(&spvModuleRegion.back());

  // Some patterns call `lookupTargetEnv` during conversion; they would fail
  // if run after GPUModuleConversion unless the `TargetEnv` attribute is
  // preserved. Copy the TargetEnvAttr only if it is attached directly to the
  // GPUModuleOp.
  if (auto attr = moduleOp->getAttrOfType<spirv::TargetEnvAttr>(
          spirv::getTargetEnvAttrName()))
    spvModule->setAttr(spirv::getTargetEnvAttrName(), attr);

  rewriter.eraseOp(moduleOp);
  return success();
}
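
// Illustrative example (not part of the upstream file):
//   gpu.module @kernels { ... }
// becomes roughly
//   spirv.module @__spv__kernels Logical GLSL450 { ... }
// where the addressing and memory models are deduced from 'spirv.target_env'
// (Logical/GLSL450 is just one possible outcome, and the assembly shown is
// approximate).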

//===----------------------------------------------------------------------===//
// GPU return inside kernel functions to SPIR-V return.
//===----------------------------------------------------------------------===//

LogicalResult GPUReturnOpConversion::matchAndRewrite(
    gpu::ReturnOp returnOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (!adaptor.getOperands().empty())
    return failure();

  rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp);
  return success();
}

//===----------------------------------------------------------------------===//
// Barrier.
//===----------------------------------------------------------------------===//

LogicalResult GPUBarrierConversion::matchAndRewrite(
    gpu::BarrierOp barrierOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  MLIRContext *context = getContext();
  // Both execution and memory scope should be workgroup.
  auto scope = spirv::ScopeAttr::get(context, spirv::Scope::Workgroup);
  // Require acquire and release memory semantics for workgroup memory.
  auto memorySemantics = spirv::MemorySemanticsAttr::get(
      context, spirv::MemorySemantics::WorkgroupMemory |
                   spirv::MemorySemantics::AcquireRelease);
  rewriter.replaceOpWithNewOp<spirv::ControlBarrierOp>(barrierOp, scope, scope,
                                                       memorySemantics);
  return success();
}
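
// Illustrative example (not part of the upstream file): `gpu.barrier` becomes
// roughly
//   spirv.ControlBarrier <Workgroup>, <Workgroup>, <AcquireRelease|WorkgroupMemory>
// i.e. a workgroup-scoped execution barrier with acquire/release semantics on
// workgroup memory. The assembly syntax is approximate.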

//===----------------------------------------------------------------------===//
// Shuffle
//===----------------------------------------------------------------------===//

LogicalResult GPUShuffleConversion::matchAndRewrite(
    gpu::ShuffleOp shuffleOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // Require the shuffle width to be the same as the target's subgroup size,
  // given that for SPIR-V non-uniform subgroup ops, we cannot select
  // participating invocations.
  auto targetEnv = getTypeConverter<SPIRVTypeConverter>()->getTargetEnv();
  unsigned subgroupSize =
      targetEnv.getAttr().getResourceLimits().getSubgroupSize();
  IntegerAttr widthAttr;
  if (!matchPattern(shuffleOp.getWidth(), m_Constant(&widthAttr)) ||
      widthAttr.getValue().getZExtValue() != subgroupSize)
    return rewriter.notifyMatchFailure(
        shuffleOp, "shuffle width and target subgroup size mismatch");

  Location loc = shuffleOp.getLoc();
  Value trueVal = spirv::ConstantOp::getOne(rewriter.getI1Type(),
                                            shuffleOp.getLoc(), rewriter);
  auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
  Value result;

  switch (shuffleOp.getMode()) {
  case gpu::ShuffleMode::XOR:
    result = rewriter.create<spirv::GroupNonUniformShuffleXorOp>(
        loc, scope, adaptor.getValue(), adaptor.getOffset());
    break;
  case gpu::ShuffleMode::IDX:
    result = rewriter.create<spirv::GroupNonUniformShuffleOp>(
        loc, scope, adaptor.getValue(), adaptor.getOffset());
    break;
  default:
    return rewriter.notifyMatchFailure(shuffleOp, "unimplemented shuffle mode");
  }

  rewriter.replaceOp(shuffleOp, {result, trueVal});
  return success();
}
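
// Illustrative example (not part of the upstream file): with a subgroup size
// of 32,
//   %res, %valid = gpu.shuffle xor %val, %mask, %c32 : f32
// becomes roughly a subgroup-scoped
//   %res = spirv.GroupNonUniformShuffleXor <Subgroup> %val, %mask : f32, i32
// plus a constant `true` replacing the validity result, since every
// invocation in the subgroup participates. The assembly syntax is approximate.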

//===----------------------------------------------------------------------===//
// Group ops
//===----------------------------------------------------------------------===//

template <typename UniformOp, typename NonUniformOp>
static Value createGroupReduceOpImpl(OpBuilder &builder, Location loc,
                                     Value arg, bool isGroup, bool isUniform) {
  Type type = arg.getType();
  auto scope = mlir::spirv::ScopeAttr::get(builder.getContext(),
                                           isGroup ? spirv::Scope::Workgroup
                                                   : spirv::Scope::Subgroup);
  auto groupOp = spirv::GroupOperationAttr::get(builder.getContext(),
                                                spirv::GroupOperation::Reduce);
  if (isUniform) {
    return builder.create<UniformOp>(loc, type, scope, groupOp, arg)
        .getResult();
  }
  return builder.create<NonUniformOp>(loc, type, scope, groupOp, arg, Value{})
      .getResult();
}

static std::optional<Value> createGroupReduceOp(OpBuilder &builder,
                                                Location loc, Value arg,
                                                gpu::AllReduceOperation opType,
                                                bool isGroup, bool isUniform) {
  enum class ElemType { Float, Boolean, Integer };
  using FuncT = Value (*)(OpBuilder &, Location, Value, bool, bool);
  struct OpHandler {
    gpu::AllReduceOperation kind;
    ElemType elemType;
    FuncT func;
  };

  Type type = arg.getType();
  ElemType elementType;
  if (isa<FloatType>(type)) {
    elementType = ElemType::Float;
  } else if (auto intTy = dyn_cast<IntegerType>(type)) {
    elementType = (intTy.getIntOrFloatBitWidth() == 1) ? ElemType::Boolean
                                                       : ElemType::Integer;
  } else {
    return std::nullopt;
  }

  // TODO(https://github.com/llvm/llvm-project/issues/73459): The SPIR-V spec
  // does not specify how -0.0 / +0.0 and NaN values are handled in *FMin/*FMax
  // reduction ops. We should account for possible precision requirements in
  // this conversion.

  using ReduceType = gpu::AllReduceOperation;
  const OpHandler handlers[] = {
      {ReduceType::ADD, ElemType::Integer,
       &createGroupReduceOpImpl<spirv::GroupIAddOp,
                                spirv::GroupNonUniformIAddOp>},
      {ReduceType::ADD, ElemType::Float,
       &createGroupReduceOpImpl<spirv::GroupFAddOp,
                                spirv::GroupNonUniformFAddOp>},
      {ReduceType::MUL, ElemType::Integer,
       &createGroupReduceOpImpl<spirv::GroupIMulKHROp,
                                spirv::GroupNonUniformIMulOp>},
      {ReduceType::MUL, ElemType::Float,
       &createGroupReduceOpImpl<spirv::GroupFMulKHROp,
                                spirv::GroupNonUniformFMulOp>},
      {ReduceType::MINUI, ElemType::Integer,
       &createGroupReduceOpImpl<spirv::GroupUMinOp,
                                spirv::GroupNonUniformUMinOp>},
      {ReduceType::MINSI, ElemType::Integer,
       &createGroupReduceOpImpl<spirv::GroupSMinOp,
                                spirv::GroupNonUniformSMinOp>},
      {ReduceType::MINNUMF, ElemType::Float,
       &createGroupReduceOpImpl<spirv::GroupFMinOp,
                                spirv::GroupNonUniformFMinOp>},
      {ReduceType::MAXUI, ElemType::Integer,
       &createGroupReduceOpImpl<spirv::GroupUMaxOp,
                                spirv::GroupNonUniformUMaxOp>},
      {ReduceType::MAXSI, ElemType::Integer,
       &createGroupReduceOpImpl<spirv::GroupSMaxOp,
                                spirv::GroupNonUniformSMaxOp>},
      {ReduceType::MAXNUMF, ElemType::Float,
       &createGroupReduceOpImpl<spirv::GroupFMaxOp,
                                spirv::GroupNonUniformFMaxOp>},
      {ReduceType::MINIMUMF, ElemType::Float,
       &createGroupReduceOpImpl<spirv::GroupFMinOp,
                                spirv::GroupNonUniformFMinOp>},
      {ReduceType::MAXIMUMF, ElemType::Float,
       &createGroupReduceOpImpl<spirv::GroupFMaxOp,
                                spirv::GroupNonUniformFMaxOp>}};

  for (const OpHandler &handler : handlers)
    if (handler.kind == opType && elementType == handler.elemType)
      return handler.func(builder, loc, arg, isGroup, isUniform);

  return std::nullopt;
}
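
// Illustrative example (not part of the upstream file): a uniform
//   %sum = gpu.all_reduce add %val uniform {} : (f32) -> f32
// maps through the ADD/Float handler above to a workgroup-scoped
//   %sum = spirv.GroupFAdd <Workgroup> <Reduce> %val : f32
// while the non-uniform path would emit spirv.GroupNonUniformFAdd instead.
// The assembly syntax is approximate.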

/// Pattern to convert a gpu.all_reduce op into a SPIR-V group op.
class GPUAllReduceConversion final
    : public OpConversionPattern<gpu::AllReduceOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::AllReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto opType = op.getOp();

    // gpu.all_reduce can have either a reduction op attribute or a reduction
    // region. Only the attribute version is supported.
    if (!opType)
      return failure();

    auto result =
        createGroupReduceOp(rewriter, op.getLoc(), adaptor.getValue(), *opType,
                            /*isGroup=*/true, op.getUniform());
    if (!result)
      return failure();

    rewriter.replaceOp(op, *result);
    return success();
  }
};

/// Pattern to convert a gpu.subgroup_reduce op into a SPIR-V group op.
class GPUSubgroupReduceConversion final
    : public OpConversionPattern<gpu::SubgroupReduceOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getClusterSize())
      return rewriter.notifyMatchFailure(
          op, "lowering for clustered reduce not implemented");

    if (!isa<spirv::ScalarType>(adaptor.getValue().getType()))
      return rewriter.notifyMatchFailure(op, "reduction type is not a scalar");

    auto result = createGroupReduceOp(rewriter, op.getLoc(), adaptor.getValue(),
                                      adaptor.getOp(),
                                      /*isGroup=*/false, adaptor.getUniform());
    if (!result)
      return failure();

    rewriter.replaceOp(op, *result);
    return success();
  }
};

// Formulate a unique variable/constant name after
// searching in the module for existing variable/constant names.
// This is to avoid name collisions with existing variables.
// Example: printfMsg0, printfMsg1, printfMsg2, ...
static std::string makeVarName(spirv::ModuleOp moduleOp, llvm::Twine prefix) {
  std::string name;
  unsigned number = 0;

  do {
    name.clear();
    name = (prefix + llvm::Twine(number++)).str();
  } while (moduleOp.lookupSymbol(name));

  return name;
}

/// Pattern to convert a gpu.printf op into a SPIR-V CLPrintf op.

LogicalResult GPUPrintfConversion::matchAndRewrite(
    gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {

  Location loc = gpuPrintfOp.getLoc();

  auto moduleOp = gpuPrintfOp->getParentOfType<spirv::ModuleOp>();
  if (!moduleOp)
    return failure();

  // A SPIR-V global variable is used to hold the printf format string. If
  // there are multiple printf messages, each global variable needs to be
  // created with a unique name.
  std::string globalVarName = makeVarName(moduleOp, llvm::Twine("printfMsg"));
  spirv::GlobalVariableOp globalVar;

  IntegerType i8Type = rewriter.getI8Type();
  IntegerType i32Type = rewriter.getI32Type();

  // Each character of the printf format string is stored as a spec constant.
  // We need to create a unique name for each spec constant, like
  // @printfMsg0_sc0, @printfMsg0_sc1, ..., by searching the module for
  // existing spec constant names.
  auto createSpecConstant = [&](unsigned value) {
    auto attr = rewriter.getI8IntegerAttr(value);
    std::string specCstName =
        makeVarName(moduleOp, llvm::Twine(globalVarName) + "_sc");

    return rewriter.create<spirv::SpecConstantOp>(
        loc, rewriter.getStringAttr(specCstName), attr);
  };
  {
    Operation *parent =
        SymbolTable::getNearestSymbolTable(gpuPrintfOp->getParentOp());

    ConversionPatternRewriter::InsertionGuard guard(rewriter);

    Block &entryBlock = *parent->getRegion(0).begin();
    rewriter.setInsertionPointToStart(
        &entryBlock); // insertion point at module level

    // Create the constituents by scanning the format string: each character
    // is stored as a spec constant, and these spec constants are then used to
    // create a SpecConstantCompositeOp.
    llvm::SmallString<20> formatString(adaptor.getFormat());
    formatString.push_back('\0'); // Null terminate for C.
    SmallVector<Attribute, 4> constituents;
    for (char c : formatString) {
      spirv::SpecConstantOp cSpecConstantOp = createSpecConstant(c);
      constituents.push_back(SymbolRefAttr::get(cSpecConstantOp));
    }

    // Create a SpecConstantCompositeOp to initialize the global variable.
    size_t contentSize = constituents.size();
    auto globalType = spirv::ArrayType::get(i8Type, contentSize);
    spirv::SpecConstantCompositeOp specCstComposite;
    // There will be one SpecConstantCompositeOp per printf message/global
    // variable, so there is no need to look up existing ones.
    std::string specCstCompositeName =
        (llvm::Twine(globalVarName) + "_scc").str();

    specCstComposite = rewriter.create<spirv::SpecConstantCompositeOp>(
        loc, TypeAttr::get(globalType),
        rewriter.getStringAttr(specCstCompositeName),
        rewriter.getArrayAttr(constituents));

    auto ptrType = spirv::PointerType::get(
        globalType, spirv::StorageClass::UniformConstant);

    // Define a GlobalVariableOp, initialized with the spec constant
    // composite, that holds the printf format string to be passed to the
    // SPIR-V CLPrintfOp.
    globalVar = rewriter.create<spirv::GlobalVariableOp>(
        loc, ptrType, globalVarName, FlatSymbolRefAttr::get(specCstComposite));

    globalVar->setAttr("Constant", rewriter.getUnitAttr());
  }
  // Get the SSA value of the global variable and bitcast it to an i8 pointer
  // to the format string.
  Value globalPtr = rewriter.create<spirv::AddressOfOp>(loc, globalVar);
  Value fmtStr = rewriter.create<spirv::BitcastOp>(
      loc,
      spirv::PointerType::get(i8Type, spirv::StorageClass::UniformConstant),
      globalPtr);

  // Get printf arguments.
  auto printfArgs = llvm::to_vector_of<Value, 4>(adaptor.getArgs());

  rewriter.create<spirv::CLPrintfOp>(loc, i32Type, fmtStr, printfArgs);

  // gpu.printf has no results, while spirv::CLPrintfOp returns an i32, so the
  // op cannot simply be replaced with the new SPIR-V printf op; erase it
  // instead.
  rewriter.eraseOp(gpuPrintfOp);

  return success();
}
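
// Illustrative example (not part of the upstream file):
//   gpu.printf "val: %d\n", %v : i32
// expands roughly into, at module level,
//   spirv.SpecConstant @printfMsg0_sc0 = 118 : i8   // 'v', one per character
//   ...
//   spirv.SpecConstantComposite @printfMsg0_scc (...) : !spirv.array<9 x i8>
//   spirv.GlobalVariable @printfMsg0 initializer(@printfMsg0_scc) {Constant}
// and, at the use site, an AddressOf plus Bitcast of @printfMsg0 feeding a
// CL printf op with the arguments appended. The names follow the makeVarName
// scheme above; the assembly syntax is approximate.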

//===----------------------------------------------------------------------===//
// GPU To SPIRV Patterns.
//===----------------------------------------------------------------------===//

void mlir::populateGPUToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
                                      RewritePatternSet &patterns) {
  patterns.add<
      GPUBarrierConversion, GPUFuncOpConversion, GPUModuleConversion,
      GPUReturnOpConversion, GPUShuffleConversion,
      LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
      LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
      LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,
      LaunchConfigConversion<gpu::ThreadIdOp,
                             spirv::BuiltIn::LocalInvocationId>,
      LaunchConfigConversion<gpu::GlobalIdOp,
                             spirv::BuiltIn::GlobalInvocationId>,
      SingleDimLaunchConfigConversion<gpu::SubgroupIdOp,
                                      spirv::BuiltIn::SubgroupId>,
      SingleDimLaunchConfigConversion<gpu::NumSubgroupsOp,
                                      spirv::BuiltIn::NumSubgroups>,
      SingleDimLaunchConfigConversion<gpu::SubgroupSizeOp,
                                      spirv::BuiltIn::SubgroupSize>,
      WorkGroupSizeConversion, GPUAllReduceConversion,
      GPUSubgroupReduceConversion, GPUPrintfConversion>(typeConverter,
                                                        patterns.getContext());
}
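
// Illustrative usage sketch (not part of the upstream file): a conversion
// pass would typically combine these patterns with a SPIRVTypeConverter and
// SPIRVConversionTarget along these lines:
//
//   spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(module);
//   std::unique_ptr<ConversionTarget> target =
//       SPIRVConversionTarget::get(targetAttr);
//   SPIRVTypeConverter typeConverter(targetAttr);
//   RewritePatternSet patterns(module.getContext());
//   populateGPUToSPIRVPatterns(typeConverter, patterns);
//   if (failed(applyPartialConversion(kernelModules, *target,
//                                     std::move(patterns))))
//     signalPassFailure();
//
// The exact wiring lives in the GPU-to-SPIR-V conversion pass and may differ
// from this sketch.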