MLIR  20.0.0git
GPUToSPIRV.cpp
Go to the documentation of this file.
1 //===- GPUToSPIRV.cpp - GPU to SPIR-V Patterns ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert GPU dialect to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
22 #include "mlir/IR/BuiltinOps.h"
23 #include "mlir/IR/Matchers.h"
25 #include <optional>
26 
27 using namespace mlir;
28 
/// Prefix prepended to a converted gpu.module's name to avoid a symbol
/// clash with the original module (see GPUModuleConversion below).
static constexpr const char kSPIRVModule[] = "__spv__";
30 
31 namespace {
32 /// Pattern lowering GPU block/thread size/id to loading SPIR-V invocation
33 /// builtin variables.
34 template <typename SourceOp, spirv::BuiltIn builtin>
35 class LaunchConfigConversion : public OpConversionPattern<SourceOp> {
36 public:
38 
39  LogicalResult
40  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
41  ConversionPatternRewriter &rewriter) const override;
42 };
43 
44 /// Pattern lowering subgroup size/id to loading SPIR-V invocation
45 /// builtin variables.
46 template <typename SourceOp, spirv::BuiltIn builtin>
47 class SingleDimLaunchConfigConversion : public OpConversionPattern<SourceOp> {
48 public:
50 
51  LogicalResult
52  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
53  ConversionPatternRewriter &rewriter) const override;
54 };
55 
56 /// This is separate because in Vulkan workgroup size is exposed to shaders via
57 /// a constant with WorkgroupSize decoration. So here we cannot generate a
58 /// builtin variable; instead the information in the `spirv.entry_point_abi`
59 /// attribute on the surrounding FuncOp is used to replace the gpu::BlockDimOp.
60 class WorkGroupSizeConversion : public OpConversionPattern<gpu::BlockDimOp> {
61 public:
62  WorkGroupSizeConversion(TypeConverter &typeConverter, MLIRContext *context)
63  : OpConversionPattern(typeConverter, context, /*benefit*/ 10) {}
64 
65  LogicalResult
66  matchAndRewrite(gpu::BlockDimOp op, OpAdaptor adaptor,
67  ConversionPatternRewriter &rewriter) const override;
68 };
69 
70 /// Pattern to convert a kernel function in GPU dialect within a spirv.module.
71 class GPUFuncOpConversion final : public OpConversionPattern<gpu::GPUFuncOp> {
72 public:
74 
75  LogicalResult
76  matchAndRewrite(gpu::GPUFuncOp funcOp, OpAdaptor adaptor,
77  ConversionPatternRewriter &rewriter) const override;
78 
79 private:
80  SmallVector<int32_t, 3> workGroupSizeAsInt32;
81 };
82 
83 /// Pattern to convert a gpu.module to a spirv.module.
84 class GPUModuleConversion final : public OpConversionPattern<gpu::GPUModuleOp> {
85 public:
87 
88  LogicalResult
89  matchAndRewrite(gpu::GPUModuleOp moduleOp, OpAdaptor adaptor,
90  ConversionPatternRewriter &rewriter) const override;
91 };
92 
93 class GPUModuleEndConversion final
94  : public OpConversionPattern<gpu::ModuleEndOp> {
95 public:
97 
98  LogicalResult
99  matchAndRewrite(gpu::ModuleEndOp endOp, OpAdaptor adaptor,
100  ConversionPatternRewriter &rewriter) const override {
101  rewriter.eraseOp(endOp);
102  return success();
103  }
104 };
105 
106 /// Pattern to convert a gpu.return into a SPIR-V return.
107 // TODO: This can go to DRR when GPU return has operands.
108 class GPUReturnOpConversion final : public OpConversionPattern<gpu::ReturnOp> {
109 public:
111 
112  LogicalResult
113  matchAndRewrite(gpu::ReturnOp returnOp, OpAdaptor adaptor,
114  ConversionPatternRewriter &rewriter) const override;
115 };
116 
117 /// Pattern to convert a gpu.barrier op into a spirv.ControlBarrier op.
118 class GPUBarrierConversion final : public OpConversionPattern<gpu::BarrierOp> {
119 public:
121 
122  LogicalResult
123  matchAndRewrite(gpu::BarrierOp barrierOp, OpAdaptor adaptor,
124  ConversionPatternRewriter &rewriter) const override;
125 };
126 
127 /// Pattern to convert a gpu.shuffle op into a spirv.GroupNonUniformShuffle op.
128 class GPUShuffleConversion final : public OpConversionPattern<gpu::ShuffleOp> {
129 public:
131 
132  LogicalResult
133  matchAndRewrite(gpu::ShuffleOp shuffleOp, OpAdaptor adaptor,
134  ConversionPatternRewriter &rewriter) const override;
135 };
136 
137 } // namespace
138 
139 //===----------------------------------------------------------------------===//
140 // Builtins.
141 //===----------------------------------------------------------------------===//
142 
143 template <typename SourceOp, spirv::BuiltIn builtin>
144 LogicalResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
145  SourceOp op, typename SourceOp::Adaptor adaptor,
146  ConversionPatternRewriter &rewriter) const {
147  auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
148  Type indexType = typeConverter->getIndexType();
149 
150  // For Vulkan, these SPIR-V builtin variables are required to be a vector of
151  // type <3xi32> by the spec:
152  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/NumWorkgroups.html
153  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/WorkgroupId.html
154  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/WorkgroupSize.html
155  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/LocalInvocationId.html
156  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/LocalInvocationId.html
157  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/GlobalInvocationId.html
158  //
159  // For OpenCL, it depends on the Physical32/Physical64 addressing model:
160  // https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_Env.html#_built_in_variables
161  bool forShader =
162  typeConverter->getTargetEnv().allows(spirv::Capability::Shader);
163  Type builtinType = forShader ? rewriter.getIntegerType(32) : indexType;
164 
165  Value vector =
166  spirv::getBuiltinVariableValue(op, builtin, builtinType, rewriter);
167  Value dim = rewriter.create<spirv::CompositeExtractOp>(
168  op.getLoc(), builtinType, vector,
169  rewriter.getI32ArrayAttr({static_cast<int32_t>(op.getDimension())}));
170  if (forShader && builtinType != indexType)
171  dim = rewriter.create<spirv::UConvertOp>(op.getLoc(), indexType, dim);
172  rewriter.replaceOp(op, dim);
173  return success();
174 }
175 
176 template <typename SourceOp, spirv::BuiltIn builtin>
177 LogicalResult
178 SingleDimLaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
179  SourceOp op, typename SourceOp::Adaptor adaptor,
180  ConversionPatternRewriter &rewriter) const {
181  auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
182  Type indexType = typeConverter->getIndexType();
183  Type i32Type = rewriter.getIntegerType(32);
184 
185  // For Vulkan, these SPIR-V builtin variables are required to be a vector of
186  // type i32 by the spec:
187  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/NumSubgroups.html
188  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/SubgroupId.html
189  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/SubgroupSize.html
190  //
191  // For OpenCL, they are also required to be i32:
192  // https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_Env.html#_built_in_variables
193  Value builtinValue =
194  spirv::getBuiltinVariableValue(op, builtin, i32Type, rewriter);
195  if (i32Type != indexType)
196  builtinValue = rewriter.create<spirv::UConvertOp>(op.getLoc(), indexType,
197  builtinValue);
198  rewriter.replaceOp(op, builtinValue);
199  return success();
200 }
201 
202 LogicalResult WorkGroupSizeConversion::matchAndRewrite(
203  gpu::BlockDimOp op, OpAdaptor adaptor,
204  ConversionPatternRewriter &rewriter) const {
205  DenseI32ArrayAttr workGroupSizeAttr = spirv::lookupLocalWorkGroupSize(op);
206  if (!workGroupSizeAttr)
207  return failure();
208 
209  int val =
210  workGroupSizeAttr.asArrayRef()[static_cast<int32_t>(op.getDimension())];
211  auto convertedType =
212  getTypeConverter()->convertType(op.getResult().getType());
213  if (!convertedType)
214  return failure();
215  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
216  op, convertedType, IntegerAttr::get(convertedType, val));
217  return success();
218 }
219 
220 //===----------------------------------------------------------------------===//
221 // GPUFuncOp
222 //===----------------------------------------------------------------------===//
223 
224 // Legalizes a GPU function as an entry SPIR-V function.
225 static spirv::FuncOp
226 lowerAsEntryFunction(gpu::GPUFuncOp funcOp, const TypeConverter &typeConverter,
227  ConversionPatternRewriter &rewriter,
228  spirv::EntryPointABIAttr entryPointInfo,
230  auto fnType = funcOp.getFunctionType();
231  if (fnType.getNumResults()) {
232  funcOp.emitError("SPIR-V lowering only supports entry functions"
233  "with no return values right now");
234  return nullptr;
235  }
236  if (!argABIInfo.empty() && fnType.getNumInputs() != argABIInfo.size()) {
237  funcOp.emitError(
238  "lowering as entry functions requires ABI info for all arguments "
239  "or none of them");
240  return nullptr;
241  }
242  // Update the signature to valid SPIR-V types and add the ABI
243  // attributes. These will be "materialized" by using the
244  // LowerABIAttributesPass.
245  TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
246  {
247  for (const auto &argType :
248  enumerate(funcOp.getFunctionType().getInputs())) {
249  auto convertedType = typeConverter.convertType(argType.value());
250  if (!convertedType)
251  return nullptr;
252  signatureConverter.addInputs(argType.index(), convertedType);
253  }
254  }
255  auto newFuncOp = rewriter.create<spirv::FuncOp>(
256  funcOp.getLoc(), funcOp.getName(),
257  rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
258  std::nullopt));
259  for (const auto &namedAttr : funcOp->getAttrs()) {
260  if (namedAttr.getName() == funcOp.getFunctionTypeAttrName() ||
261  namedAttr.getName() == SymbolTable::getSymbolAttrName())
262  continue;
263  newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
264  }
265 
266  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
267  newFuncOp.end());
268  if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
269  &signatureConverter)))
270  return nullptr;
271  rewriter.eraseOp(funcOp);
272 
273  // Set the attributes for argument and the function.
274  StringRef argABIAttrName = spirv::getInterfaceVarABIAttrName();
275  for (auto argIndex : llvm::seq<unsigned>(0, argABIInfo.size())) {
276  newFuncOp.setArgAttr(argIndex, argABIAttrName, argABIInfo[argIndex]);
277  }
278  newFuncOp->setAttr(spirv::getEntryPointABIAttrName(), entryPointInfo);
279 
280  return newFuncOp;
281 }
282 
283 /// Populates `argABI` with spirv.interface_var_abi attributes for lowering
284 /// gpu.func to spirv.func if no arguments have the attributes set
285 /// already. Returns failure if any argument has the ABI attribute set already.
286 static LogicalResult
287 getDefaultABIAttrs(const spirv::TargetEnv &targetEnv, gpu::GPUFuncOp funcOp,
289  if (!spirv::needsInterfaceVarABIAttrs(targetEnv))
290  return success();
291 
292  for (auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
293  if (funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
295  return failure();
296  // Vulkan's interface variable requirements needs scalars to be wrapped in a
297  // struct. The struct held in storage buffer.
298  std::optional<spirv::StorageClass> sc;
299  if (funcOp.getArgument(argIndex).getType().isIntOrIndexOrFloat())
300  sc = spirv::StorageClass::StorageBuffer;
301  argABI.push_back(
302  spirv::getInterfaceVarABIAttr(0, argIndex, sc, funcOp.getContext()));
303  }
304  return success();
305 }
306 
307 LogicalResult GPUFuncOpConversion::matchAndRewrite(
308  gpu::GPUFuncOp funcOp, OpAdaptor adaptor,
309  ConversionPatternRewriter &rewriter) const {
310  if (!gpu::GPUDialect::isKernel(funcOp))
311  return failure();
312 
313  auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
315  if (failed(
316  getDefaultABIAttrs(typeConverter->getTargetEnv(), funcOp, argABI))) {
317  argABI.clear();
318  for (auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
319  // If the ABI is already specified, use it.
320  auto abiAttr = funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
322  if (!abiAttr) {
323  funcOp.emitRemark(
324  "match failure: missing 'spirv.interface_var_abi' attribute at "
325  "argument ")
326  << argIndex;
327  return failure();
328  }
329  argABI.push_back(abiAttr);
330  }
331  }
332 
333  auto entryPointAttr = spirv::lookupEntryPointABI(funcOp);
334  if (!entryPointAttr) {
335  funcOp.emitRemark(
336  "match failure: missing 'spirv.entry_point_abi' attribute");
337  return failure();
338  }
339  spirv::FuncOp newFuncOp = lowerAsEntryFunction(
340  funcOp, *getTypeConverter(), rewriter, entryPointAttr, argABI);
341  if (!newFuncOp)
342  return failure();
343  newFuncOp->removeAttr(
344  rewriter.getStringAttr(gpu::GPUDialect::getKernelFuncAttrName()));
345  return success();
346 }
347 
348 //===----------------------------------------------------------------------===//
349 // ModuleOp with gpu.module.
350 //===----------------------------------------------------------------------===//
351 
352 LogicalResult GPUModuleConversion::matchAndRewrite(
353  gpu::GPUModuleOp moduleOp, OpAdaptor adaptor,
354  ConversionPatternRewriter &rewriter) const {
355  auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
356  const spirv::TargetEnv &targetEnv = typeConverter->getTargetEnv();
357  spirv::AddressingModel addressingModel = spirv::getAddressingModel(
358  targetEnv, typeConverter->getOptions().use64bitIndex);
359  FailureOr<spirv::MemoryModel> memoryModel = spirv::getMemoryModel(targetEnv);
360  if (failed(memoryModel))
361  return moduleOp.emitRemark(
362  "cannot deduce memory model from 'spirv.target_env'");
363 
364  // Add a keyword to the module name to avoid symbolic conflict.
365  std::string spvModuleName = (kSPIRVModule + moduleOp.getName()).str();
366  auto spvModule = rewriter.create<spirv::ModuleOp>(
367  moduleOp.getLoc(), addressingModel, *memoryModel, std::nullopt,
368  StringRef(spvModuleName));
369 
370  // Move the region from the module op into the SPIR-V module.
371  Region &spvModuleRegion = spvModule.getRegion();
372  rewriter.inlineRegionBefore(moduleOp.getBodyRegion(), spvModuleRegion,
373  spvModuleRegion.begin());
374  // The spirv.module build method adds a block. Remove that.
375  rewriter.eraseBlock(&spvModuleRegion.back());
376 
377  // Some of the patterns call `lookupTargetEnv` during conversion and they
378  // will fail if called after GPUModuleConversion and we don't preserve
379  // `TargetEnv` attribute.
380  // Copy TargetEnvAttr only if it is attached directly to the GPUModuleOp.
381  if (auto attr = moduleOp->getAttrOfType<spirv::TargetEnvAttr>(
383  spvModule->setAttr(spirv::getTargetEnvAttrName(), attr);
384 
385  rewriter.eraseOp(moduleOp);
386  return success();
387 }
388 
389 //===----------------------------------------------------------------------===//
390 // GPU return inside kernel functions to SPIR-V return.
391 //===----------------------------------------------------------------------===//
392 
393 LogicalResult GPUReturnOpConversion::matchAndRewrite(
394  gpu::ReturnOp returnOp, OpAdaptor adaptor,
395  ConversionPatternRewriter &rewriter) const {
396  if (!adaptor.getOperands().empty())
397  return failure();
398 
399  rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp);
400  return success();
401 }
402 
403 //===----------------------------------------------------------------------===//
404 // Barrier.
405 //===----------------------------------------------------------------------===//
406 
407 LogicalResult GPUBarrierConversion::matchAndRewrite(
408  gpu::BarrierOp barrierOp, OpAdaptor adaptor,
409  ConversionPatternRewriter &rewriter) const {
410  MLIRContext *context = getContext();
411  // Both execution and memory scope should be workgroup.
412  auto scope = spirv::ScopeAttr::get(context, spirv::Scope::Workgroup);
413  // Require acquire and release memory semantics for workgroup memory.
414  auto memorySemantics = spirv::MemorySemanticsAttr::get(
415  context, spirv::MemorySemantics::WorkgroupMemory |
416  spirv::MemorySemantics::AcquireRelease);
417  rewriter.replaceOpWithNewOp<spirv::ControlBarrierOp>(barrierOp, scope, scope,
418  memorySemantics);
419  return success();
420 }
421 
422 //===----------------------------------------------------------------------===//
423 // Shuffle
424 //===----------------------------------------------------------------------===//
425 
426 LogicalResult GPUShuffleConversion::matchAndRewrite(
427  gpu::ShuffleOp shuffleOp, OpAdaptor adaptor,
428  ConversionPatternRewriter &rewriter) const {
429  // Require the shuffle width to be the same as the target's subgroup size,
430  // given that for SPIR-V non-uniform subgroup ops, we cannot select
431  // participating invocations.
432  auto targetEnv = getTypeConverter<SPIRVTypeConverter>()->getTargetEnv();
433  unsigned subgroupSize =
434  targetEnv.getAttr().getResourceLimits().getSubgroupSize();
435  IntegerAttr widthAttr;
436  if (!matchPattern(shuffleOp.getWidth(), m_Constant(&widthAttr)) ||
437  widthAttr.getValue().getZExtValue() != subgroupSize)
438  return rewriter.notifyMatchFailure(
439  shuffleOp, "shuffle width and target subgroup size mismatch");
440 
441  Location loc = shuffleOp.getLoc();
442  Value trueVal = spirv::ConstantOp::getOne(rewriter.getI1Type(),
443  shuffleOp.getLoc(), rewriter);
444  auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
445  Value result;
446 
447  switch (shuffleOp.getMode()) {
448  case gpu::ShuffleMode::XOR:
449  result = rewriter.create<spirv::GroupNonUniformShuffleXorOp>(
450  loc, scope, adaptor.getValue(), adaptor.getOffset());
451  break;
452  case gpu::ShuffleMode::IDX:
453  result = rewriter.create<spirv::GroupNonUniformShuffleOp>(
454  loc, scope, adaptor.getValue(), adaptor.getOffset());
455  break;
456  default:
457  return rewriter.notifyMatchFailure(shuffleOp, "unimplemented shuffle mode");
458  }
459 
460  rewriter.replaceOp(shuffleOp, {result, trueVal});
461  return success();
462 }
463 
464 //===----------------------------------------------------------------------===//
465 // Group ops
466 //===----------------------------------------------------------------------===//
467 
468 template <typename UniformOp, typename NonUniformOp>
470  Value arg, bool isGroup, bool isUniform) {
471  Type type = arg.getType();
472  auto scope = mlir::spirv::ScopeAttr::get(builder.getContext(),
473  isGroup ? spirv::Scope::Workgroup
474  : spirv::Scope::Subgroup);
475  auto groupOp = spirv::GroupOperationAttr::get(builder.getContext(),
476  spirv::GroupOperation::Reduce);
477  if (isUniform) {
478  return builder.create<UniformOp>(loc, type, scope, groupOp, arg)
479  .getResult();
480  }
481  return builder.create<NonUniformOp>(loc, type, scope, groupOp, arg, Value{})
482  .getResult();
483 }
484 
/// Creates the SPIR-V group reduction op corresponding to the given
/// gpu reduction `opType` for `arg`'s element type, dispatching through a
/// static (reduction kind, element type) -> emitter table. Returns
/// std::nullopt when the combination has no handler (note the table below has
/// no entries for boolean element types).
static std::optional<Value> createGroupReduceOp(OpBuilder &builder,
                                                Location loc, Value arg,
                                                gpu::AllReduceOperation opType,
                                                bool isGroup, bool isUniform) {
  enum class ElemType { Float, Boolean, Integer };
  // Emitter signature shared by all createGroupReduceOpImpl instantiations.
  using FuncT = Value (*)(OpBuilder &, Location, Value, bool, bool);
  struct OpHandler {
    gpu::AllReduceOperation kind;
    ElemType elemType;
    FuncT func;
  };

  // Classify the element type; anything other than float/int is unsupported.
  Type type = arg.getType();
  ElemType elementType;
  if (isa<FloatType>(type)) {
    elementType = ElemType::Float;
  } else if (auto intTy = dyn_cast<IntegerType>(type)) {
    elementType = (intTy.getIntOrFloatBitWidth() == 1) ? ElemType::Boolean
                                                       : ElemType::Integer;
  } else {
    return std::nullopt;
  }

  // TODO(https://github.com/llvm/llvm-project/issues/73459): The SPIR-V spec
  // does not specify how -0.0 / +0.0 and NaN values are handled in *FMin/*FMax
  // reduction ops. We should account possible precision requirements in this
  // conversion.

  using ReduceType = gpu::AllReduceOperation;
  // Dispatch table: each entry binds a (reduction kind, element type) pair to
  // the uniform/non-uniform op pair emitted by createGroupReduceOpImpl.
  const OpHandler handlers[] = {
      {ReduceType::ADD, ElemType::Integer,
       &createGroupReduceOpImpl<spirv::GroupIAddOp,
                                spirv::GroupNonUniformIAddOp>},
      {ReduceType::ADD, ElemType::Float,
       &createGroupReduceOpImpl<spirv::GroupFAddOp,
                                spirv::GroupNonUniformFAddOp>},
      {ReduceType::MUL, ElemType::Integer,
       &createGroupReduceOpImpl<spirv::GroupIMulKHROp,
                                spirv::GroupNonUniformIMulOp>},
      {ReduceType::MUL, ElemType::Float,
       &createGroupReduceOpImpl<spirv::GroupFMulKHROp,
                                spirv::GroupNonUniformFMulOp>},
      {ReduceType::MINUI, ElemType::Integer,
       &createGroupReduceOpImpl<spirv::GroupUMinOp,
                                spirv::GroupNonUniformUMinOp>},
      {ReduceType::MINSI, ElemType::Integer,
       &createGroupReduceOpImpl<spirv::GroupSMinOp,
                                spirv::GroupNonUniformSMinOp>},
      {ReduceType::MINNUMF, ElemType::Float,
       &createGroupReduceOpImpl<spirv::GroupFMinOp,
                                spirv::GroupNonUniformFMinOp>},
      {ReduceType::MAXUI, ElemType::Integer,
       &createGroupReduceOpImpl<spirv::GroupUMaxOp,
                                spirv::GroupNonUniformUMaxOp>},
      {ReduceType::MAXSI, ElemType::Integer,
       &createGroupReduceOpImpl<spirv::GroupSMaxOp,
                                spirv::GroupNonUniformSMaxOp>},
      {ReduceType::MAXNUMF, ElemType::Float,
       &createGroupReduceOpImpl<spirv::GroupFMaxOp,
                                spirv::GroupNonUniformFMaxOp>},
      // NOTE(review): MINIMUMF/MAXIMUMF reuse the same FMin/FMax ops as
      // MINNUMF/MAXNUMF — see the TODO above about NaN/-0.0 semantics.
      {ReduceType::MINIMUMF, ElemType::Float,
       &createGroupReduceOpImpl<spirv::GroupFMinOp,
                                spirv::GroupNonUniformFMinOp>},
      {ReduceType::MAXIMUMF, ElemType::Float,
       &createGroupReduceOpImpl<spirv::GroupFMaxOp,
                                spirv::GroupNonUniformFMaxOp>}};

  for (const OpHandler &handler : handlers)
    if (handler.kind == opType && elementType == handler.elemType)
      return handler.func(builder, loc, arg, isGroup, isUniform);

  return std::nullopt;
}
558 
559 /// Pattern to convert a gpu.all_reduce op into a SPIR-V group op.
561  : public OpConversionPattern<gpu::AllReduceOp> {
562 public:
564 
565  LogicalResult
566  matchAndRewrite(gpu::AllReduceOp op, OpAdaptor adaptor,
567  ConversionPatternRewriter &rewriter) const override {
568  auto opType = op.getOp();
569 
570  // gpu.all_reduce can have either reduction op attribute or reduction
571  // region. Only attribute version is supported.
572  if (!opType)
573  return failure();
574 
575  auto result =
576  createGroupReduceOp(rewriter, op.getLoc(), adaptor.getValue(), *opType,
577  /*isGroup*/ true, op.getUniform());
578  if (!result)
579  return failure();
580 
581  rewriter.replaceOp(op, *result);
582  return success();
583  }
584 };
585 
586 /// Pattern to convert a gpu.subgroup_reduce op into a SPIR-V group op.
588  : public OpConversionPattern<gpu::SubgroupReduceOp> {
589 public:
591 
592  LogicalResult
593  matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
594  ConversionPatternRewriter &rewriter) const override {
595  if (!isa<spirv::ScalarType>(adaptor.getValue().getType()))
596  return rewriter.notifyMatchFailure(op, "reduction type is not a scalar");
597 
598  auto result = createGroupReduceOp(rewriter, op.getLoc(), adaptor.getValue(),
599  adaptor.getOp(),
600  /*isGroup=*/false, adaptor.getUniform());
601  if (!result)
602  return failure();
603 
604  rewriter.replaceOp(op, *result);
605  return success();
606  }
607 };
608 
609 //===----------------------------------------------------------------------===//
610 // GPU To SPIRV Patterns.
611 //===----------------------------------------------------------------------===//
612 
614  RewritePatternSet &patterns) {
615  patterns.add<
616  GPUBarrierConversion, GPUFuncOpConversion, GPUModuleConversion,
617  GPUModuleEndConversion, GPUReturnOpConversion, GPUShuffleConversion,
618  LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
619  LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
620  LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,
621  LaunchConfigConversion<gpu::ThreadIdOp,
622  spirv::BuiltIn::LocalInvocationId>,
623  LaunchConfigConversion<gpu::GlobalIdOp,
624  spirv::BuiltIn::GlobalInvocationId>,
625  SingleDimLaunchConfigConversion<gpu::SubgroupIdOp,
626  spirv::BuiltIn::SubgroupId>,
627  SingleDimLaunchConfigConversion<gpu::NumSubgroupsOp,
628  spirv::BuiltIn::NumSubgroups>,
629  SingleDimLaunchConfigConversion<gpu::SubgroupSizeOp,
630  spirv::BuiltIn::SubgroupSize>,
631  WorkGroupSizeConversion, GPUAllReduceConversion,
632  GPUSubgroupReduceConversion>(typeConverter, patterns.getContext());
633 }
static Value createGroupReduceOpImpl(OpBuilder &builder, Location loc, Value arg, bool isGroup, bool isUniform)
Definition: GPUToSPIRV.cpp:469
static LogicalResult getDefaultABIAttrs(const spirv::TargetEnv &targetEnv, gpu::GPUFuncOp funcOp, SmallVectorImpl< spirv::InterfaceVarABIAttr > &argABI)
Populates argABI with spirv.interface_var_abi attributes for lowering gpu.func to spirv....
Definition: GPUToSPIRV.cpp:287
static std::optional< Value > createGroupReduceOp(OpBuilder &builder, Location loc, Value arg, gpu::AllReduceOperation opType, bool isGroup, bool isUniform)
Definition: GPUToSPIRV.cpp:485
static constexpr const char kSPIRVModule[]
Definition: GPUToSPIRV.cpp:29
static spirv::FuncOp lowerAsEntryFunction(gpu::GPUFuncOp funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter, spirv::EntryPointABIAttr entryPointInfo, ArrayRef< spirv::InterfaceVarABIAttr > argABIInfo)
Definition: GPUToSPIRV.cpp:226
static MLIRContext * getContext(OpFoldResult val)
#define MINUI(lhs, rhs)
Pattern to convert a gpu.all_reduce op into a SPIR-V group op.
Definition: GPUToSPIRV.cpp:561
LogicalResult matchAndRewrite(gpu::AllReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
Definition: GPUToSPIRV.cpp:566
Pattern to convert a gpu.subgroup_reduce op into a SPIR-V group op.
Definition: GPUToSPIRV.cpp:588
LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
Definition: GPUToSPIRV.cpp:593
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:287
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:100
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:91
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:273
MLIRContext * getContext() const
Definition: Builders.h:55
IntegerType getI1Type()
Definition: Builders.cpp:77
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition: Builders.h:101
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void eraseBlock(Block *block) override
PatternRewriter hook for erase all operations in a block.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:210
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:468
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:577
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
MLIRContext * getContext() const
Definition: PatternMatch.h:823
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
Type conversion from builtin types to SPIR-V types for shader interface.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
Type conversion class.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
An attribute that specifies the information regarding the interface variable: descriptor set,...
An attribute that specifies the target version, allowed extensions and capabilities,...
ResourceLimitsAttr getResourceLimits() const
Returns the target resource limits.
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
Definition: TargetAndABI.h:29
TargetEnvAttr getAttr() const
Definition: TargetAndABI.h:62
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
StringRef getInterfaceVarABIAttrName()
Returns the attribute name for specifying argument ABI information.
bool needsInterfaceVarABIAttrs(TargetEnvAttr targetAttr)
Returns whether the given SPIR-V target (described by TargetEnvAttr) needs ABI attributes for interfa...
InterfaceVarABIAttr getInterfaceVarABIAttr(unsigned descriptorSet, unsigned binding, std::optional< StorageClass > storageClass, MLIRContext *context)
Gets the InterfaceVarABIAttr given its fields.
Value getBuiltinVariableValue(Operation *op, BuiltIn builtin, Type integerType, OpBuilder &builder, StringRef prefix="__builtin__", StringRef suffix="__")
Returns the value for the given builtin variable.
EntryPointABIAttr lookupEntryPointABI(Operation *op)
Queries the entry point ABI on the nearest function-like op containing the given op.
StringRef getTargetEnvAttrName()
Returns the attribute name for specifying SPIR-V target environment.
DenseI32ArrayAttr lookupLocalWorkGroupSize(Operation *op)
Queries the local workgroup size from entry point ABI on the nearest function-like op containing the ...
AddressingModel getAddressingModel(TargetEnvAttr targetAttr, bool use64bitAddress)
Returns addressing model selected based on target environment.
FailureOr< MemoryModel > getMemoryModel(TargetEnvAttr targetAttr)
Returns memory model selected based on target environment.
StringRef getEntryPointABIAttrName()
Returns the attribute name for specifying entry point information.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
void populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating GPU Ops to SPIR-V ops.
Definition: GPUToSPIRV.cpp:613
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:310