MLIR  19.0.0git
GPUToSPIRV.cpp
Go to the documentation of this file.
1 //===- GPUToSPIRV.cpp - GPU to SPIR-V Patterns ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert GPU dialect to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
22 #include "mlir/IR/BuiltinOps.h"
23 #include "mlir/IR/Matchers.h"
26 #include <optional>
27 
28 using namespace mlir;
29 
/// Prefix prepended to a gpu.module's symbol name to build the name of the
/// spirv.module generated from it, so the two symbols do not conflict.
static constexpr char kSPIRVModule[] = "__spv__";
31 
32 namespace {
33 /// Pattern lowering GPU block/thread size/id to loading SPIR-V invocation
34 /// builtin variables.
35 template <typename SourceOp, spirv::BuiltIn builtin>
36 class LaunchConfigConversion : public OpConversionPattern<SourceOp> {
37 public:
39 
41  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
42  ConversionPatternRewriter &rewriter) const override;
43 };
44 
45 /// Pattern lowering subgroup size/id to loading SPIR-V invocation
46 /// builtin variables.
47 template <typename SourceOp, spirv::BuiltIn builtin>
48 class SingleDimLaunchConfigConversion : public OpConversionPattern<SourceOp> {
49 public:
51 
53  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
54  ConversionPatternRewriter &rewriter) const override;
55 };
56 
57 /// This is separate because in Vulkan workgroup size is exposed to shaders via
58 /// a constant with WorkgroupSize decoration. So here we cannot generate a
59 /// builtin variable; instead the information in the `spirv.entry_point_abi`
60 /// attribute on the surrounding FuncOp is used to replace the gpu::BlockDimOp.
61 class WorkGroupSizeConversion : public OpConversionPattern<gpu::BlockDimOp> {
62 public:
63  WorkGroupSizeConversion(TypeConverter &typeConverter, MLIRContext *context)
64  : OpConversionPattern(typeConverter, context, /*benefit*/ 10) {}
65 
67  matchAndRewrite(gpu::BlockDimOp op, OpAdaptor adaptor,
68  ConversionPatternRewriter &rewriter) const override;
69 };
70 
71 /// Pattern to convert a kernel function in GPU dialect within a spirv.module.
72 class GPUFuncOpConversion final : public OpConversionPattern<gpu::GPUFuncOp> {
73 public:
75 
77  matchAndRewrite(gpu::GPUFuncOp funcOp, OpAdaptor adaptor,
78  ConversionPatternRewriter &rewriter) const override;
79 
80 private:
81  SmallVector<int32_t, 3> workGroupSizeAsInt32;
82 };
83 
84 /// Pattern to convert a gpu.module to a spirv.module.
85 class GPUModuleConversion final : public OpConversionPattern<gpu::GPUModuleOp> {
86 public:
88 
90  matchAndRewrite(gpu::GPUModuleOp moduleOp, OpAdaptor adaptor,
91  ConversionPatternRewriter &rewriter) const override;
92 };
93 
94 class GPUModuleEndConversion final
95  : public OpConversionPattern<gpu::ModuleEndOp> {
96 public:
98 
100  matchAndRewrite(gpu::ModuleEndOp endOp, OpAdaptor adaptor,
101  ConversionPatternRewriter &rewriter) const override {
102  rewriter.eraseOp(endOp);
103  return success();
104  }
105 };
106 
107 /// Pattern to convert a gpu.return into a SPIR-V return.
108 // TODO: This can go to DRR when GPU return has operands.
109 class GPUReturnOpConversion final : public OpConversionPattern<gpu::ReturnOp> {
110 public:
112 
114  matchAndRewrite(gpu::ReturnOp returnOp, OpAdaptor adaptor,
115  ConversionPatternRewriter &rewriter) const override;
116 };
117 
118 /// Pattern to convert a gpu.barrier op into a spirv.ControlBarrier op.
119 class GPUBarrierConversion final : public OpConversionPattern<gpu::BarrierOp> {
120 public:
122 
124  matchAndRewrite(gpu::BarrierOp barrierOp, OpAdaptor adaptor,
125  ConversionPatternRewriter &rewriter) const override;
126 };
127 
128 /// Pattern to convert a gpu.shuffle op into a spirv.GroupNonUniformShuffle op.
129 class GPUShuffleConversion final : public OpConversionPattern<gpu::ShuffleOp> {
130 public:
132 
134  matchAndRewrite(gpu::ShuffleOp shuffleOp, OpAdaptor adaptor,
135  ConversionPatternRewriter &rewriter) const override;
136 };
137 
138 } // namespace
139 
140 //===----------------------------------------------------------------------===//
141 // Builtins.
142 //===----------------------------------------------------------------------===//
143 
144 template <typename SourceOp, spirv::BuiltIn builtin>
145 LogicalResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
146  SourceOp op, typename SourceOp::Adaptor adaptor,
147  ConversionPatternRewriter &rewriter) const {
148  auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
149  Type indexType = typeConverter->getIndexType();
150 
151  // For Vulkan, these SPIR-V builtin variables are required to be a vector of
152  // type <3xi32> by the spec:
153  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/NumWorkgroups.html
154  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/WorkgroupId.html
155  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/WorkgroupSize.html
156  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/LocalInvocationId.html
157  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/LocalInvocationId.html
158  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/GlobalInvocationId.html
159  //
160  // For OpenCL, it depends on the Physical32/Physical64 addressing model:
161  // https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_Env.html#_built_in_variables
162  bool forShader =
163  typeConverter->getTargetEnv().allows(spirv::Capability::Shader);
164  Type builtinType = forShader ? rewriter.getIntegerType(32) : indexType;
165 
166  Value vector =
167  spirv::getBuiltinVariableValue(op, builtin, builtinType, rewriter);
168  Value dim = rewriter.create<spirv::CompositeExtractOp>(
169  op.getLoc(), builtinType, vector,
170  rewriter.getI32ArrayAttr({static_cast<int32_t>(op.getDimension())}));
171  if (forShader && builtinType != indexType)
172  dim = rewriter.create<spirv::UConvertOp>(op.getLoc(), indexType, dim);
173  rewriter.replaceOp(op, dim);
174  return success();
175 }
176 
177 template <typename SourceOp, spirv::BuiltIn builtin>
179 SingleDimLaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
180  SourceOp op, typename SourceOp::Adaptor adaptor,
181  ConversionPatternRewriter &rewriter) const {
182  auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
183  Type indexType = typeConverter->getIndexType();
184  Type i32Type = rewriter.getIntegerType(32);
185 
186  // For Vulkan, these SPIR-V builtin variables are required to be a vector of
187  // type i32 by the spec:
188  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/NumSubgroups.html
189  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/SubgroupId.html
190  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/SubgroupSize.html
191  //
192  // For OpenCL, they are also required to be i32:
193  // https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_Env.html#_built_in_variables
194  Value builtinValue =
195  spirv::getBuiltinVariableValue(op, builtin, i32Type, rewriter);
196  if (i32Type != indexType)
197  builtinValue = rewriter.create<spirv::UConvertOp>(op.getLoc(), indexType,
198  builtinValue);
199  rewriter.replaceOp(op, builtinValue);
200  return success();
201 }
202 
203 LogicalResult WorkGroupSizeConversion::matchAndRewrite(
204  gpu::BlockDimOp op, OpAdaptor adaptor,
205  ConversionPatternRewriter &rewriter) const {
206  DenseI32ArrayAttr workGroupSizeAttr = spirv::lookupLocalWorkGroupSize(op);
207  if (!workGroupSizeAttr)
208  return failure();
209 
210  int val =
211  workGroupSizeAttr.asArrayRef()[static_cast<int32_t>(op.getDimension())];
212  auto convertedType =
213  getTypeConverter()->convertType(op.getResult().getType());
214  if (!convertedType)
215  return failure();
216  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
217  op, convertedType, IntegerAttr::get(convertedType, val));
218  return success();
219 }
220 
221 //===----------------------------------------------------------------------===//
222 // GPUFuncOp
223 //===----------------------------------------------------------------------===//
224 
225 // Legalizes a GPU function as an entry SPIR-V function.
226 static spirv::FuncOp
227 lowerAsEntryFunction(gpu::GPUFuncOp funcOp, const TypeConverter &typeConverter,
228  ConversionPatternRewriter &rewriter,
229  spirv::EntryPointABIAttr entryPointInfo,
231  auto fnType = funcOp.getFunctionType();
232  if (fnType.getNumResults()) {
233  funcOp.emitError("SPIR-V lowering only supports entry functions"
234  "with no return values right now");
235  return nullptr;
236  }
237  if (!argABIInfo.empty() && fnType.getNumInputs() != argABIInfo.size()) {
238  funcOp.emitError(
239  "lowering as entry functions requires ABI info for all arguments "
240  "or none of them");
241  return nullptr;
242  }
243  // Update the signature to valid SPIR-V types and add the ABI
244  // attributes. These will be "materialized" by using the
245  // LowerABIAttributesPass.
246  TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
247  {
248  for (const auto &argType :
249  enumerate(funcOp.getFunctionType().getInputs())) {
250  auto convertedType = typeConverter.convertType(argType.value());
251  if (!convertedType)
252  return nullptr;
253  signatureConverter.addInputs(argType.index(), convertedType);
254  }
255  }
256  auto newFuncOp = rewriter.create<spirv::FuncOp>(
257  funcOp.getLoc(), funcOp.getName(),
258  rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
259  std::nullopt));
260  for (const auto &namedAttr : funcOp->getAttrs()) {
261  if (namedAttr.getName() == funcOp.getFunctionTypeAttrName() ||
262  namedAttr.getName() == SymbolTable::getSymbolAttrName())
263  continue;
264  newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
265  }
266 
267  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
268  newFuncOp.end());
269  if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
270  &signatureConverter)))
271  return nullptr;
272  rewriter.eraseOp(funcOp);
273 
274  // Set the attributes for argument and the function.
275  StringRef argABIAttrName = spirv::getInterfaceVarABIAttrName();
276  for (auto argIndex : llvm::seq<unsigned>(0, argABIInfo.size())) {
277  newFuncOp.setArgAttr(argIndex, argABIAttrName, argABIInfo[argIndex]);
278  }
279  newFuncOp->setAttr(spirv::getEntryPointABIAttrName(), entryPointInfo);
280 
281  return newFuncOp;
282 }
283 
284 /// Populates `argABI` with spirv.interface_var_abi attributes for lowering
285 /// gpu.func to spirv.func if no arguments have the attributes set
286 /// already. Returns failure if any argument has the ABI attribute set already.
287 static LogicalResult
288 getDefaultABIAttrs(const spirv::TargetEnv &targetEnv, gpu::GPUFuncOp funcOp,
290  if (!spirv::needsInterfaceVarABIAttrs(targetEnv))
291  return success();
292 
293  for (auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
294  if (funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
296  return failure();
297  // Vulkan's interface variable requirements needs scalars to be wrapped in a
298  // struct. The struct held in storage buffer.
299  std::optional<spirv::StorageClass> sc;
300  if (funcOp.getArgument(argIndex).getType().isIntOrIndexOrFloat())
301  sc = spirv::StorageClass::StorageBuffer;
302  argABI.push_back(
303  spirv::getInterfaceVarABIAttr(0, argIndex, sc, funcOp.getContext()));
304  }
305  return success();
306 }
307 
308 LogicalResult GPUFuncOpConversion::matchAndRewrite(
309  gpu::GPUFuncOp funcOp, OpAdaptor adaptor,
310  ConversionPatternRewriter &rewriter) const {
311  if (!gpu::GPUDialect::isKernel(funcOp))
312  return failure();
313 
314  auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
316  if (failed(
317  getDefaultABIAttrs(typeConverter->getTargetEnv(), funcOp, argABI))) {
318  argABI.clear();
319  for (auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
320  // If the ABI is already specified, use it.
321  auto abiAttr = funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
323  if (!abiAttr) {
324  funcOp.emitRemark(
325  "match failure: missing 'spirv.interface_var_abi' attribute at "
326  "argument ")
327  << argIndex;
328  return failure();
329  }
330  argABI.push_back(abiAttr);
331  }
332  }
333 
334  auto entryPointAttr = spirv::lookupEntryPointABI(funcOp);
335  if (!entryPointAttr) {
336  funcOp.emitRemark(
337  "match failure: missing 'spirv.entry_point_abi' attribute");
338  return failure();
339  }
340  spirv::FuncOp newFuncOp = lowerAsEntryFunction(
341  funcOp, *getTypeConverter(), rewriter, entryPointAttr, argABI);
342  if (!newFuncOp)
343  return failure();
344  newFuncOp->removeAttr(
345  rewriter.getStringAttr(gpu::GPUDialect::getKernelFuncAttrName()));
346  return success();
347 }
348 
349 //===----------------------------------------------------------------------===//
350 // ModuleOp with gpu.module.
351 //===----------------------------------------------------------------------===//
352 
353 LogicalResult GPUModuleConversion::matchAndRewrite(
354  gpu::GPUModuleOp moduleOp, OpAdaptor adaptor,
355  ConversionPatternRewriter &rewriter) const {
356  auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
357  const spirv::TargetEnv &targetEnv = typeConverter->getTargetEnv();
358  spirv::AddressingModel addressingModel = spirv::getAddressingModel(
359  targetEnv, typeConverter->getOptions().use64bitIndex);
360  FailureOr<spirv::MemoryModel> memoryModel = spirv::getMemoryModel(targetEnv);
361  if (failed(memoryModel))
362  return moduleOp.emitRemark(
363  "cannot deduce memory model from 'spirv.target_env'");
364 
365  // Add a keyword to the module name to avoid symbolic conflict.
366  std::string spvModuleName = (kSPIRVModule + moduleOp.getName()).str();
367  auto spvModule = rewriter.create<spirv::ModuleOp>(
368  moduleOp.getLoc(), addressingModel, *memoryModel, std::nullopt,
369  StringRef(spvModuleName));
370 
371  // Move the region from the module op into the SPIR-V module.
372  Region &spvModuleRegion = spvModule.getRegion();
373  rewriter.inlineRegionBefore(moduleOp.getBodyRegion(), spvModuleRegion,
374  spvModuleRegion.begin());
375  // The spirv.module build method adds a block. Remove that.
376  rewriter.eraseBlock(&spvModuleRegion.back());
377 
378  // Some of the patterns call `lookupTargetEnv` during conversion and they
379  // will fail if called after GPUModuleConversion and we don't preserve
380  // `TargetEnv` attribute.
381  // Copy TargetEnvAttr only if it is attached directly to the GPUModuleOp.
382  if (auto attr = moduleOp->getAttrOfType<spirv::TargetEnvAttr>(
384  spvModule->setAttr(spirv::getTargetEnvAttrName(), attr);
385 
386  rewriter.eraseOp(moduleOp);
387  return success();
388 }
389 
390 //===----------------------------------------------------------------------===//
391 // GPU return inside kernel functions to SPIR-V return.
392 //===----------------------------------------------------------------------===//
393 
394 LogicalResult GPUReturnOpConversion::matchAndRewrite(
395  gpu::ReturnOp returnOp, OpAdaptor adaptor,
396  ConversionPatternRewriter &rewriter) const {
397  if (!adaptor.getOperands().empty())
398  return failure();
399 
400  rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp);
401  return success();
402 }
403 
404 //===----------------------------------------------------------------------===//
405 // Barrier.
406 //===----------------------------------------------------------------------===//
407 
408 LogicalResult GPUBarrierConversion::matchAndRewrite(
409  gpu::BarrierOp barrierOp, OpAdaptor adaptor,
410  ConversionPatternRewriter &rewriter) const {
411  MLIRContext *context = getContext();
412  // Both execution and memory scope should be workgroup.
413  auto scope = spirv::ScopeAttr::get(context, spirv::Scope::Workgroup);
414  // Require acquire and release memory semantics for workgroup memory.
415  auto memorySemantics = spirv::MemorySemanticsAttr::get(
416  context, spirv::MemorySemantics::WorkgroupMemory |
417  spirv::MemorySemantics::AcquireRelease);
418  rewriter.replaceOpWithNewOp<spirv::ControlBarrierOp>(barrierOp, scope, scope,
419  memorySemantics);
420  return success();
421 }
422 
423 //===----------------------------------------------------------------------===//
424 // Shuffle
425 //===----------------------------------------------------------------------===//
426 
427 LogicalResult GPUShuffleConversion::matchAndRewrite(
428  gpu::ShuffleOp shuffleOp, OpAdaptor adaptor,
429  ConversionPatternRewriter &rewriter) const {
430  // Require the shuffle width to be the same as the target's subgroup size,
431  // given that for SPIR-V non-uniform subgroup ops, we cannot select
432  // participating invocations.
433  auto targetEnv = getTypeConverter<SPIRVTypeConverter>()->getTargetEnv();
434  unsigned subgroupSize =
435  targetEnv.getAttr().getResourceLimits().getSubgroupSize();
436  IntegerAttr widthAttr;
437  if (!matchPattern(shuffleOp.getWidth(), m_Constant(&widthAttr)) ||
438  widthAttr.getValue().getZExtValue() != subgroupSize)
439  return rewriter.notifyMatchFailure(
440  shuffleOp, "shuffle width and target subgroup size mismatch");
441 
442  Location loc = shuffleOp.getLoc();
443  Value trueVal = spirv::ConstantOp::getOne(rewriter.getI1Type(),
444  shuffleOp.getLoc(), rewriter);
445  auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
446  Value result;
447 
448  switch (shuffleOp.getMode()) {
449  case gpu::ShuffleMode::XOR:
450  result = rewriter.create<spirv::GroupNonUniformShuffleXorOp>(
451  loc, scope, adaptor.getValue(), adaptor.getOffset());
452  break;
453  case gpu::ShuffleMode::IDX:
454  result = rewriter.create<spirv::GroupNonUniformShuffleOp>(
455  loc, scope, adaptor.getValue(), adaptor.getOffset());
456  break;
457  default:
458  return rewriter.notifyMatchFailure(shuffleOp, "unimplemented shuffle mode");
459  }
460 
461  rewriter.replaceOp(shuffleOp, {result, trueVal});
462  return success();
463 }
464 
465 //===----------------------------------------------------------------------===//
466 // Group ops
467 //===----------------------------------------------------------------------===//
468 
469 template <typename UniformOp, typename NonUniformOp>
471  Value arg, bool isGroup, bool isUniform) {
472  Type type = arg.getType();
473  auto scope = mlir::spirv::ScopeAttr::get(builder.getContext(),
474  isGroup ? spirv::Scope::Workgroup
475  : spirv::Scope::Subgroup);
476  auto groupOp = spirv::GroupOperationAttr::get(builder.getContext(),
477  spirv::GroupOperation::Reduce);
478  if (isUniform) {
479  return builder.create<UniformOp>(loc, type, scope, groupOp, arg)
480  .getResult();
481  }
482  return builder.create<NonUniformOp>(loc, type, scope, groupOp, arg, Value{})
483  .getResult();
484 }
485 
486 static std::optional<Value> createGroupReduceOp(OpBuilder &builder,
487  Location loc, Value arg,
488  gpu::AllReduceOperation opType,
489  bool isGroup, bool isUniform) {
490  enum class ElemType { Float, Boolean, Integer };
491  using FuncT = Value (*)(OpBuilder &, Location, Value, bool, bool);
492  struct OpHandler {
493  gpu::AllReduceOperation kind;
494  ElemType elemType;
495  FuncT func;
496  };
497 
498  Type type = arg.getType();
499  ElemType elementType;
500  if (isa<FloatType>(type)) {
501  elementType = ElemType::Float;
502  } else if (auto intTy = dyn_cast<IntegerType>(type)) {
503  elementType = (intTy.getIntOrFloatBitWidth() == 1) ? ElemType::Boolean
504  : ElemType::Integer;
505  } else {
506  return std::nullopt;
507  }
508 
509  // TODO(https://github.com/llvm/llvm-project/issues/73459): The SPIR-V spec
510  // does not specify how -0.0 / +0.0 and NaN values are handled in *FMin/*FMax
511  // reduction ops. We should account possible precision requirements in this
512  // conversion.
513 
514  using ReduceType = gpu::AllReduceOperation;
515  const OpHandler handlers[] = {
516  {ReduceType::ADD, ElemType::Integer,
517  &createGroupReduceOpImpl<spirv::GroupIAddOp,
518  spirv::GroupNonUniformIAddOp>},
519  {ReduceType::ADD, ElemType::Float,
520  &createGroupReduceOpImpl<spirv::GroupFAddOp,
521  spirv::GroupNonUniformFAddOp>},
522  {ReduceType::MUL, ElemType::Integer,
523  &createGroupReduceOpImpl<spirv::GroupIMulKHROp,
524  spirv::GroupNonUniformIMulOp>},
525  {ReduceType::MUL, ElemType::Float,
526  &createGroupReduceOpImpl<spirv::GroupFMulKHROp,
527  spirv::GroupNonUniformFMulOp>},
528  {ReduceType::MINUI, ElemType::Integer,
529  &createGroupReduceOpImpl<spirv::GroupUMinOp,
530  spirv::GroupNonUniformUMinOp>},
531  {ReduceType::MINSI, ElemType::Integer,
532  &createGroupReduceOpImpl<spirv::GroupSMinOp,
533  spirv::GroupNonUniformSMinOp>},
534  {ReduceType::MINNUMF, ElemType::Float,
535  &createGroupReduceOpImpl<spirv::GroupFMinOp,
536  spirv::GroupNonUniformFMinOp>},
537  {ReduceType::MAXUI, ElemType::Integer,
538  &createGroupReduceOpImpl<spirv::GroupUMaxOp,
539  spirv::GroupNonUniformUMaxOp>},
540  {ReduceType::MAXSI, ElemType::Integer,
541  &createGroupReduceOpImpl<spirv::GroupSMaxOp,
542  spirv::GroupNonUniformSMaxOp>},
543  {ReduceType::MAXNUMF, ElemType::Float,
544  &createGroupReduceOpImpl<spirv::GroupFMaxOp,
545  spirv::GroupNonUniformFMaxOp>},
546  {ReduceType::MINIMUMF, ElemType::Float,
547  &createGroupReduceOpImpl<spirv::GroupFMinOp,
548  spirv::GroupNonUniformFMinOp>},
549  {ReduceType::MAXIMUMF, ElemType::Float,
550  &createGroupReduceOpImpl<spirv::GroupFMaxOp,
551  spirv::GroupNonUniformFMaxOp>}};
552 
553  for (const OpHandler &handler : handlers)
554  if (handler.kind == opType && elementType == handler.elemType)
555  return handler.func(builder, loc, arg, isGroup, isUniform);
556 
557  return std::nullopt;
558 }
559 
560 /// Pattern to convert a gpu.all_reduce op into a SPIR-V group op.
562  : public OpConversionPattern<gpu::AllReduceOp> {
563 public:
565 
567  matchAndRewrite(gpu::AllReduceOp op, OpAdaptor adaptor,
568  ConversionPatternRewriter &rewriter) const override {
569  auto opType = op.getOp();
570 
571  // gpu.all_reduce can have either reduction op attribute or reduction
572  // region. Only attribute version is supported.
573  if (!opType)
574  return failure();
575 
576  auto result =
577  createGroupReduceOp(rewriter, op.getLoc(), adaptor.getValue(), *opType,
578  /*isGroup*/ true, op.getUniform());
579  if (!result)
580  return failure();
581 
582  rewriter.replaceOp(op, *result);
583  return success();
584  }
585 };
586 
587 /// Pattern to convert a gpu.subgroup_reduce op into a SPIR-V group op.
589  : public OpConversionPattern<gpu::SubgroupReduceOp> {
590 public:
592 
594  matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
595  ConversionPatternRewriter &rewriter) const override {
596  if (!isa<spirv::ScalarType>(adaptor.getValue().getType()))
597  return rewriter.notifyMatchFailure(op, "reduction type is not a scalar");
598 
599  auto result = createGroupReduceOp(rewriter, op.getLoc(), adaptor.getValue(),
600  adaptor.getOp(),
601  /*isGroup=*/false, adaptor.getUniform());
602  if (!result)
603  return failure();
604 
605  rewriter.replaceOp(op, *result);
606  return success();
607  }
608 };
609 
610 //===----------------------------------------------------------------------===//
611 // GPU To SPIRV Patterns.
612 //===----------------------------------------------------------------------===//
613 
615  RewritePatternSet &patterns) {
616  patterns.add<
617  GPUBarrierConversion, GPUFuncOpConversion, GPUModuleConversion,
618  GPUModuleEndConversion, GPUReturnOpConversion, GPUShuffleConversion,
619  LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
620  LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
621  LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,
622  LaunchConfigConversion<gpu::ThreadIdOp,
623  spirv::BuiltIn::LocalInvocationId>,
624  LaunchConfigConversion<gpu::GlobalIdOp,
625  spirv::BuiltIn::GlobalInvocationId>,
626  SingleDimLaunchConfigConversion<gpu::SubgroupIdOp,
627  spirv::BuiltIn::SubgroupId>,
628  SingleDimLaunchConfigConversion<gpu::NumSubgroupsOp,
629  spirv::BuiltIn::NumSubgroups>,
630  SingleDimLaunchConfigConversion<gpu::SubgroupSizeOp,
631  spirv::BuiltIn::SubgroupSize>,
632  WorkGroupSizeConversion, GPUAllReduceConversion,
633  GPUSubgroupReduceConversion>(typeConverter, patterns.getContext());
634 }
static Value createGroupReduceOpImpl(OpBuilder &builder, Location loc, Value arg, bool isGroup, bool isUniform)
Definition: GPUToSPIRV.cpp:470
static LogicalResult getDefaultABIAttrs(const spirv::TargetEnv &targetEnv, gpu::GPUFuncOp funcOp, SmallVectorImpl< spirv::InterfaceVarABIAttr > &argABI)
Populates argABI with spirv.interface_var_abi attributes for lowering gpu.func to spirv....
Definition: GPUToSPIRV.cpp:288
static std::optional< Value > createGroupReduceOp(OpBuilder &builder, Location loc, Value arg, gpu::AllReduceOperation opType, bool isGroup, bool isUniform)
Definition: GPUToSPIRV.cpp:486
static constexpr const char kSPIRVModule[]
Definition: GPUToSPIRV.cpp:30
static spirv::FuncOp lowerAsEntryFunction(gpu::GPUFuncOp funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter, spirv::EntryPointABIAttr entryPointInfo, ArrayRef< spirv::InterfaceVarABIAttr > argABIInfo)
Definition: GPUToSPIRV.cpp:227
static MLIRContext * getContext(OpFoldResult val)
#define MINUI(lhs, rhs)
Pattern to convert a gpu.all_reduce op into a SPIR-V group op.
Definition: GPUToSPIRV.cpp:562
LogicalResult matchAndRewrite(gpu::AllReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
Definition: GPUToSPIRV.cpp:567
Pattern to convert a gpu.subgroup_reduce op into a SPIR-V group op.
Definition: GPUToSPIRV.cpp:589
LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
Definition: GPUToSPIRV.cpp:594
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:283
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:96
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:87
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:269
MLIRContext * getContext() const
Definition: Builders.h:55
IntegerType getI1Type()
Definition: Builders.cpp:73
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition: Builders.h:100
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Convert the types of block arguments within the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void eraseBlock(Block *block) override
PatternRewriter hook for erasing all operations in a block.
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:209
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:577
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
Type conversion from builtin types to SPIR-V types for shader interface.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
Type conversion class.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
An attribute that specifies the information regarding the interface variable: descriptor set,...
An attribute that specifies the target version, allowed extensions and capabilities,...
ResourceLimitsAttr getResourceLimits() const
Returns the target resource limits.
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
Definition: TargetAndABI.h:29
TargetEnvAttr getAttr() const
Definition: TargetAndABI.h:62
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
StringRef getInterfaceVarABIAttrName()
Returns the attribute name for specifying argument ABI information.
bool needsInterfaceVarABIAttrs(TargetEnvAttr targetAttr)
Returns whether the given SPIR-V target (described by TargetEnvAttr) needs ABI attributes for interfa...
InterfaceVarABIAttr getInterfaceVarABIAttr(unsigned descriptorSet, unsigned binding, std::optional< StorageClass > storageClass, MLIRContext *context)
Gets the InterfaceVarABIAttr given its fields.
Value getBuiltinVariableValue(Operation *op, BuiltIn builtin, Type integerType, OpBuilder &builder, StringRef prefix="__builtin__", StringRef suffix="__")
Returns the value for the given builtin variable.
EntryPointABIAttr lookupEntryPointABI(Operation *op)
Queries the entry point ABI on the nearest function-like op containing the given op.
StringRef getTargetEnvAttrName()
Returns the attribute name for specifying SPIR-V target environment.
DenseI32ArrayAttr lookupLocalWorkGroupSize(Operation *op)
Queries the local workgroup size from entry point ABI on the nearest function-like op containing the ...
AddressingModel getAddressingModel(TargetEnvAttr targetAttr, bool use64bitAddress)
Returns addressing model selected based on target environment.
FailureOr< MemoryModel > getMemoryModel(TargetEnvAttr targetAttr)
Returns memory model selected based on target environment.
StringRef getEntryPointABIAttrName()
Returns the attribute name for specifying entry point information.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
void populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating GPU Ops to SPIR-V ops.
Definition: GPUToSPIRV.cpp:614
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:310
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26