MLIR  17.0.0git
SPIRVConversion.cpp
Go to the documentation of this file.
1 //===- SPIRVConversion.cpp - SPIR-V Conversion Utilities ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities used to lower to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/Support/Debug.h"
23 
24 #include <functional>
25 #include <optional>
26 
27 #define DEBUG_TYPE "mlir-spirv-conversion"
28 
29 using namespace mlir;
30 
31 //===----------------------------------------------------------------------===//
32 // Utility functions
33 //===----------------------------------------------------------------------===//
34 
35 /// Checks that the `candidates` extension requirements can be satisfied with
36 /// the given `targetEnv`.
37 ///
38 /// `candidates` is a vector of vector for extension requirements following
39 /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
40 /// convention.
41 template <typename LabelT>
43  LabelT label, const spirv::TargetEnv &targetEnv,
45  for (const auto &ors : candidates) {
46  if (targetEnv.allows(ors))
47  continue;
48 
49  LLVM_DEBUG({
50  SmallVector<StringRef> extStrings;
51  for (spirv::Extension ext : ors)
52  extStrings.push_back(spirv::stringifyExtension(ext));
53 
54  llvm::dbgs() << label << " illegal: requires at least one extension in ["
55  << llvm::join(extStrings, ", ")
56  << "] but none allowed in target environment\n";
57  });
58  return failure();
59  }
60  return success();
61 }
62 
63 /// Checks that the `candidates` capability requirements can be satisfied with
64 /// the given `targetEnv`.
65 ///
66 /// `candidates` is a vector of vector for capability requirements following
67 /// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
68 /// convention.
69 template <typename LabelT>
71  LabelT label, const spirv::TargetEnv &targetEnv,
73  for (const auto &ors : candidates) {
74  if (targetEnv.allows(ors))
75  continue;
76 
77  LLVM_DEBUG({
78  SmallVector<StringRef> capStrings;
79  for (spirv::Capability cap : ors)
80  capStrings.push_back(spirv::stringifyCapability(cap));
81 
82  llvm::dbgs() << label << " illegal: requires at least one capability in ["
83  << llvm::join(capStrings, ", ")
84  << "] but none allowed in target environment\n";
85  });
86  return failure();
87  }
88  return success();
89 }
90 
91 /// Returns true if the given `storageClass` needs explicit layout when used in
92 /// Shader environments.
93 static bool needsExplicitLayout(spirv::StorageClass storageClass) {
94  switch (storageClass) {
95  case spirv::StorageClass::PhysicalStorageBuffer:
96  case spirv::StorageClass::PushConstant:
97  case spirv::StorageClass::StorageBuffer:
98  case spirv::StorageClass::Uniform:
99  return true;
100  default:
101  return false;
102  }
103 }
104 
105 /// Wraps the given `elementType` in a struct and gets the pointer to the
106 /// struct. This is used to satisfy Vulkan interface requirements.
107 static spirv::PointerType
108 wrapInStructAndGetPointer(Type elementType, spirv::StorageClass storageClass) {
109  auto structType = needsExplicitLayout(storageClass)
110  ? spirv::StructType::get(elementType, /*offsetInfo=*/0)
111  : spirv::StructType::get(elementType);
112  return spirv::PointerType::get(structType, storageClass);
113 }
114 
115 //===----------------------------------------------------------------------===//
116 // Type Conversion
117 //===----------------------------------------------------------------------===//
118 
121  return cast<spirv::ScalarType>(
122  IntegerType::get(ctx, options.use64bitIndex ? 64 : 32));
123 }
124 
126  return ::getIndexType(getContext(), options);
127 }
128 
129 MLIRContext *SPIRVTypeConverter::getContext() const {
130  return targetEnv.getAttr().getContext();
131 }
132 
133 bool SPIRVTypeConverter::allows(spirv::Capability capability) {
134  return targetEnv.allows(capability);
135 }
136 
137 // TODO: This is a utility function that should probably be exposed by the
138 // SPIR-V dialect. Keeping it local till the use case arises.
139 static std::optional<int64_t>
141  if (type.isa<spirv::ScalarType>()) {
142  auto bitWidth = type.getIntOrFloatBitWidth();
143  // According to the SPIR-V spec:
144  // "There is no physical size or bit pattern defined for values with boolean
145  // type. If they are stored (in conjunction with OpVariable), they can only
146  // be used with logical addressing operations, not physical, and only with
147  // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
148  // Private, Function, Input, and Output."
149  if (bitWidth == 1)
150  return std::nullopt;
151  return bitWidth / 8;
152  }
153 
154  if (auto complexType = type.dyn_cast<ComplexType>()) {
155  auto elementSize = getTypeNumBytes(options, complexType.getElementType());
156  if (!elementSize)
157  return std::nullopt;
158  return 2 * *elementSize;
159  }
160 
161  if (auto vecType = type.dyn_cast<VectorType>()) {
162  auto elementSize = getTypeNumBytes(options, vecType.getElementType());
163  if (!elementSize)
164  return std::nullopt;
165  return vecType.getNumElements() * *elementSize;
166  }
167 
168  if (auto memRefType = type.dyn_cast<MemRefType>()) {
169  // TODO: Layout should also be controlled by the ABI attributes. For now
170  // using the layout from MemRef.
171  int64_t offset;
172  SmallVector<int64_t, 4> strides;
173  if (!memRefType.hasStaticShape() ||
174  failed(getStridesAndOffset(memRefType, strides, offset)))
175  return std::nullopt;
176 
177  // To get the size of the memref object in memory, the total size is the
178  // max(stride * dimension-size) computed for all dimensions times the size
179  // of the element.
180  auto elementSize = getTypeNumBytes(options, memRefType.getElementType());
181  if (!elementSize)
182  return std::nullopt;
183 
184  if (memRefType.getRank() == 0)
185  return elementSize;
186 
187  auto dims = memRefType.getShape();
188  if (llvm::is_contained(dims, ShapedType::kDynamic) ||
189  ShapedType::isDynamic(offset) ||
190  llvm::is_contained(strides, ShapedType::kDynamic))
191  return std::nullopt;
192 
193  int64_t memrefSize = -1;
194  for (const auto &shape : enumerate(dims))
195  memrefSize = std::max(memrefSize, shape.value() * strides[shape.index()]);
196 
197  return (offset + memrefSize) * *elementSize;
198  }
199 
200  if (auto tensorType = type.dyn_cast<TensorType>()) {
201  if (!tensorType.hasStaticShape())
202  return std::nullopt;
203 
204  auto elementSize = getTypeNumBytes(options, tensorType.getElementType());
205  if (!elementSize)
206  return std::nullopt;
207 
208  int64_t size = *elementSize;
209  for (auto shape : tensorType.getShape())
210  size *= shape;
211 
212  return size;
213  }
214 
215  // TODO: Add size computation for other types.
216  return std::nullopt;
217 }
218 
219 /// Converts a scalar `type` to a suitable type under the given `targetEnv`.
220 static Type
223  std::optional<spirv::StorageClass> storageClass = {}) {
224  // Get extension and capability requirements for the given type.
227  type.getExtensions(extensions, storageClass);
228  type.getCapabilities(capabilities, storageClass);
229 
230  // If all requirements are met, then we can accept this type as-is.
231  if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
232  succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
233  return type;
234 
235  // Otherwise we need to adjust the type, which really means adjusting the
236  // bitwidth given this is a scalar type.
237  if (!options.emulateLT32BitScalarTypes)
238  return nullptr;
239 
240  // We only emulate narrower scalar types here and do not truncate results.
241  if (type.getIntOrFloatBitWidth() > 32) {
242  LLVM_DEBUG(llvm::dbgs()
243  << type
244  << " not converted to 32-bit for SPIR-V to avoid truncation\n");
245  return nullptr;
246  }
247 
248  if (auto floatType = type.dyn_cast<FloatType>()) {
249  LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
250  return Builder(targetEnv.getContext()).getF32Type();
251  }
252 
253  auto intType = type.cast<IntegerType>();
254  LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
255  return IntegerType::get(targetEnv.getContext(), /*width=*/32,
256  intType.getSignedness());
257 }
258 
259 /// Returns a type with the same shape but with any index element type converted
260 /// to the matching integer type. This is a noop when the element type is not
261 /// the index type.
262 static ShapedType
263 convertIndexElementType(ShapedType type,
265  Type indexType = dyn_cast<IndexType>(type.getElementType());
266  if (!indexType)
267  return type;
268 
269  return type.clone(getIndexType(type.getContext(), options));
270 }
271 
272 /// Converts a vector `type` to a suitable type under the given `targetEnv`.
273 static Type
275  const SPIRVConversionOptions &options, VectorType type,
276  std::optional<spirv::StorageClass> storageClass = {}) {
277  type = cast<VectorType>(convertIndexElementType(type, options));
278  auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
279  if (!scalarType) {
280  LLVM_DEBUG(llvm::dbgs()
281  << type << " illegal: cannot convert non-scalar element type\n");
282  return nullptr;
283  }
284 
285  if (type.getRank() <= 1 && type.getNumElements() == 1)
286  return convertScalarType(targetEnv, options, scalarType, storageClass);
287 
288  if (!spirv::CompositeType::isValid(type)) {
289  LLVM_DEBUG(llvm::dbgs() << type << " illegal: > 4-element unimplemented\n");
290  return nullptr;
291  }
292 
293  // Get extension and capability requirements for the given type.
296  type.cast<spirv::CompositeType>().getExtensions(extensions, storageClass);
297  type.cast<spirv::CompositeType>().getCapabilities(capabilities, storageClass);
298 
299  // If all requirements are met, then we can accept this type as-is.
300  if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
301  succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
302  return type;
303 
304  auto elementType =
305  convertScalarType(targetEnv, options, scalarType, storageClass);
306  if (elementType)
307  return VectorType::get(type.getShape(), elementType);
308  return nullptr;
309 }
310 
311 static Type
313  const SPIRVConversionOptions &options, ComplexType type,
314  std::optional<spirv::StorageClass> storageClass = {}) {
315  auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
316  if (!scalarType) {
317  LLVM_DEBUG(llvm::dbgs()
318  << type << " illegal: cannot convert non-scalar element type\n");
319  return nullptr;
320  }
321 
322  auto elementType =
323  convertScalarType(targetEnv, options, scalarType, storageClass);
324  if (!elementType)
325  return nullptr;
326  if (elementType != type.getElementType()) {
327  LLVM_DEBUG(llvm::dbgs()
328  << type << " illegal: complex type emulation unsupported\n");
329  return nullptr;
330  }
331 
332  return VectorType::get(2, elementType);
333 }
334 
335 /// Converts a tensor `type` to a suitable type under the given `targetEnv`.
336 ///
337 /// Note that this is mainly for lowering constant tensors. In SPIR-V one can
338 /// create composite constants with OpConstantComposite to embed relatively large
339 /// constant values and use OpCompositeExtract and OpCompositeInsert to
340 /// manipulate, like what we do for vectors.
341 static Type convertTensorType(const spirv::TargetEnv &targetEnv,
343  TensorType type) {
344  // TODO: Handle dynamic shapes.
345  if (!type.hasStaticShape()) {
346  LLVM_DEBUG(llvm::dbgs()
347  << type << " illegal: dynamic shape unimplemented\n");
348  return nullptr;
349  }
350 
351  type = cast<TensorType>(convertIndexElementType(type, options));
352  auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
353  if (!scalarType) {
354  LLVM_DEBUG(llvm::dbgs()
355  << type << " illegal: cannot convert non-scalar element type\n");
356  return nullptr;
357  }
358 
359  std::optional<int64_t> scalarSize = getTypeNumBytes(options, scalarType);
360  std::optional<int64_t> tensorSize = getTypeNumBytes(options, type);
361  if (!scalarSize || !tensorSize) {
362  LLVM_DEBUG(llvm::dbgs()
363  << type << " illegal: cannot deduce element count\n");
364  return nullptr;
365  }
366 
367  auto arrayElemCount = *tensorSize / *scalarSize;
368  auto arrayElemType = convertScalarType(targetEnv, options, scalarType);
369  if (!arrayElemType)
370  return nullptr;
371  std::optional<int64_t> arrayElemSize =
372  getTypeNumBytes(options, arrayElemType);
373  if (!arrayElemSize) {
374  LLVM_DEBUG(llvm::dbgs()
375  << type << " illegal: cannot deduce converted element size\n");
376  return nullptr;
377  }
378 
379  return spirv::ArrayType::get(arrayElemType, arrayElemCount);
380 }
381 
384  MemRefType type,
385  spirv::StorageClass storageClass) {
386  unsigned numBoolBits = options.boolNumBits;
387  if (numBoolBits != 8) {
388  LLVM_DEBUG(llvm::dbgs()
389  << "using non-8-bit storage for bool types unimplemented");
390  return nullptr;
391  }
392  auto elementType = IntegerType::get(type.getContext(), numBoolBits)
393  .dyn_cast<spirv::ScalarType>();
394  if (!elementType)
395  return nullptr;
396  Type arrayElemType =
397  convertScalarType(targetEnv, options, elementType, storageClass);
398  if (!arrayElemType)
399  return nullptr;
400  std::optional<int64_t> arrayElemSize =
401  getTypeNumBytes(options, arrayElemType);
402  if (!arrayElemSize) {
403  LLVM_DEBUG(llvm::dbgs()
404  << type << " illegal: cannot deduce converted element size\n");
405  return nullptr;
406  }
407 
408  if (!type.hasStaticShape()) {
409  // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
410  // to the element.
411  if (targetEnv.allows(spirv::Capability::Kernel))
412  return spirv::PointerType::get(arrayElemType, storageClass);
413  int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
414  auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
415  // For Vulkan we need extra wrapping struct and array to satisfy interface
416  // needs.
417  return wrapInStructAndGetPointer(arrayType, storageClass);
418  }
419 
420  int64_t memrefSize = (type.getNumElements() * numBoolBits + 7) / 8;
421  auto arrayElemCount = llvm::divideCeil(memrefSize, *arrayElemSize);
422  int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
423  auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
424  if (targetEnv.allows(spirv::Capability::Kernel))
425  return spirv::PointerType::get(arrayType, storageClass);
426  return wrapInStructAndGetPointer(arrayType, storageClass);
427 }
428 
429 static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
431  MemRefType type) {
432  auto attr = type.getMemorySpace().dyn_cast_or_null<spirv::StorageClassAttr>();
433  if (!attr) {
434  LLVM_DEBUG(
435  llvm::dbgs()
436  << type
437  << " illegal: expected memory space to be a SPIR-V storage class "
438  "attribute; please use MemorySpaceToStorageClassConverter to map "
439  "numeric memory spaces beforehand\n");
440  return nullptr;
441  }
442  spirv::StorageClass storageClass = attr.getValue();
443 
444  if (type.getElementType().isa<IntegerType>() &&
445  type.getElementTypeBitWidth() == 1) {
446  return convertBoolMemrefType(targetEnv, options, type, storageClass);
447  }
448 
449  Type arrayElemType;
450  Type elementType = type.getElementType();
451  if (auto vecType = elementType.dyn_cast<VectorType>()) {
452  arrayElemType =
453  convertVectorType(targetEnv, options, vecType, storageClass);
454  } else if (auto complexType = elementType.dyn_cast<ComplexType>()) {
455  arrayElemType =
456  convertComplexType(targetEnv, options, complexType, storageClass);
457  } else if (auto scalarType = elementType.dyn_cast<spirv::ScalarType>()) {
458  arrayElemType =
459  convertScalarType(targetEnv, options, scalarType, storageClass);
460  } else if (auto indexType = elementType.dyn_cast<IndexType>()) {
461  type = convertIndexElementType(type, options).cast<MemRefType>();
462  arrayElemType = type.getElementType();
463  } else {
464  LLVM_DEBUG(
465  llvm::dbgs()
466  << type
467  << " unhandled: can only convert scalar or vector element type\n");
468  return nullptr;
469  }
470  if (!arrayElemType)
471  return nullptr;
472 
473  std::optional<int64_t> arrayElemSize =
474  getTypeNumBytes(options, arrayElemType);
475  if (!arrayElemSize) {
476  LLVM_DEBUG(llvm::dbgs()
477  << type << " illegal: cannot deduce converted element size\n");
478  return nullptr;
479  }
480 
481  if (!type.hasStaticShape()) {
482  // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
483  // to the element.
484  if (targetEnv.allows(spirv::Capability::Kernel))
485  return spirv::PointerType::get(arrayElemType, storageClass);
486  int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
487  auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
488  // For Vulkan we need extra wrapping struct and array to satisfy interface
489  // needs.
490  return wrapInStructAndGetPointer(arrayType, storageClass);
491  }
492 
493  std::optional<int64_t> memrefSize = getTypeNumBytes(options, type);
494  if (!memrefSize) {
495  LLVM_DEBUG(llvm::dbgs()
496  << type << " illegal: cannot deduce element count\n");
497  return nullptr;
498  }
499 
500  auto arrayElemCount = llvm::divideCeil(*memrefSize, *arrayElemSize);
501  int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
502  auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
503  if (targetEnv.allows(spirv::Capability::Kernel))
504  return spirv::PointerType::get(arrayType, storageClass);
505  return wrapInStructAndGetPointer(arrayType, storageClass);
506 }
507 
510  : targetEnv(targetAttr), options(options) {
511  // Add conversions. The order matters here: later ones will be tried earlier.
512 
513  // Allow all SPIR-V dialect specific types. This assumes all builtin types
514  // adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType)
515  // were tried before.
516  //
517  // TODO: this assumes that the SPIR-V types are valid to use in
518  // the given target environment, which should be the case if the whole
519  // pipeline is driven by the same target environment. Still, we probably still
520  // want to validate and convert to be safe.
521  addConversion([](spirv::SPIRVType type) { return type; });
522 
523  addConversion([this](IndexType /*indexType*/) { return getIndexType(); });
524 
525  addConversion([this](IntegerType intType) -> std::optional<Type> {
526  if (auto scalarType = intType.dyn_cast<spirv::ScalarType>())
527  return convertScalarType(this->targetEnv, this->options, scalarType);
528  return Type();
529  });
530 
531  addConversion([this](FloatType floatType) -> std::optional<Type> {
532  if (auto scalarType = floatType.dyn_cast<spirv::ScalarType>())
533  return convertScalarType(this->targetEnv, this->options, scalarType);
534  return Type();
535  });
536 
537  addConversion([this](ComplexType complexType) {
538  return convertComplexType(this->targetEnv, this->options, complexType);
539  });
540 
541  addConversion([this](VectorType vectorType) {
542  return convertVectorType(this->targetEnv, this->options, vectorType);
543  });
544 
545  addConversion([this](TensorType tensorType) {
546  return convertTensorType(this->targetEnv, this->options, tensorType);
547  });
548 
549  addConversion([this](MemRefType memRefType) {
550  return convertMemrefType(this->targetEnv, this->options, memRefType);
551  });
552 }
553 
554 //===----------------------------------------------------------------------===//
555 // func::FuncOp Conversion Patterns
556 //===----------------------------------------------------------------------===//
557 
558 namespace {
559 /// A pattern for rewriting function signature to convert arguments of functions
560 /// to be of valid SPIR-V types.
561 class FuncOpConversion final : public OpConversionPattern<func::FuncOp> {
562 public:
564 
566  matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
567  ConversionPatternRewriter &rewriter) const override;
568 };
569 } // namespace
570 
572 FuncOpConversion::matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
573  ConversionPatternRewriter &rewriter) const {
574  auto fnType = funcOp.getFunctionType();
575  if (fnType.getNumResults() > 1)
576  return failure();
577 
578  TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
579  for (const auto &argType : enumerate(fnType.getInputs())) {
580  auto convertedType = getTypeConverter()->convertType(argType.value());
581  if (!convertedType)
582  return failure();
583  signatureConverter.addInputs(argType.index(), convertedType);
584  }
585 
586  Type resultType;
587  if (fnType.getNumResults() == 1) {
588  resultType = getTypeConverter()->convertType(fnType.getResult(0));
589  if (!resultType)
590  return failure();
591  }
592 
593  // Create the converted spirv.func op.
594  auto newFuncOp = rewriter.create<spirv::FuncOp>(
595  funcOp.getLoc(), funcOp.getName(),
596  rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
597  resultType ? TypeRange(resultType)
598  : TypeRange()));
599 
600  // Copy over all attributes other than the function name and type.
601  for (const auto &namedAttr : funcOp->getAttrs()) {
602  if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() &&
603  namedAttr.getName() != SymbolTable::getSymbolAttrName())
604  newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
605  }
606 
607  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
608  newFuncOp.end());
609  if (failed(rewriter.convertRegionTypes(
610  &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter)))
611  return failure();
612  rewriter.eraseOp(funcOp);
613  return success();
614 }
615 
617  RewritePatternSet &patterns) {
618  patterns.add<FuncOpConversion>(typeConverter, patterns.getContext());
619 }
620 
621 //===----------------------------------------------------------------------===//
622 // Builtin Variables
623 //===----------------------------------------------------------------------===//
624 
625 static spirv::GlobalVariableOp getBuiltinVariable(Block &body,
626  spirv::BuiltIn builtin) {
627  // Look through all global variables in the given `body` block and check if
628  // there is a spirv.GlobalVariable that has the same `builtin` attribute.
629  for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
630  if (auto builtinAttr = varOp->getAttrOfType<StringAttr>(
631  spirv::SPIRVDialect::getAttributeName(
632  spirv::Decoration::BuiltIn))) {
633  auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
634  if (varBuiltIn && *varBuiltIn == builtin) {
635  return varOp;
636  }
637  }
638  }
639  return nullptr;
640 }
641 
642 /// Gets name of global variable for a builtin.
643 static std::string getBuiltinVarName(spirv::BuiltIn builtin) {
644  return std::string("__builtin_var_") + stringifyBuiltIn(builtin).str() + "__";
645 }
646 
647 /// Gets or inserts a global variable for a builtin within `body` block.
648 static spirv::GlobalVariableOp
649 getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
650  Type integerType, OpBuilder &builder) {
651  if (auto varOp = getBuiltinVariable(body, builtin))
652  return varOp;
653 
654  OpBuilder::InsertionGuard guard(builder);
655  builder.setInsertionPointToStart(&body);
656 
657  spirv::GlobalVariableOp newVarOp;
658  switch (builtin) {
659  case spirv::BuiltIn::NumWorkgroups:
660  case spirv::BuiltIn::WorkgroupSize:
661  case spirv::BuiltIn::WorkgroupId:
662  case spirv::BuiltIn::LocalInvocationId:
663  case spirv::BuiltIn::GlobalInvocationId: {
664  auto ptrType = spirv::PointerType::get(VectorType::get({3}, integerType),
665  spirv::StorageClass::Input);
666  std::string name = getBuiltinVarName(builtin);
667  newVarOp =
668  builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
669  break;
670  }
671  case spirv::BuiltIn::SubgroupId:
672  case spirv::BuiltIn::NumSubgroups:
673  case spirv::BuiltIn::SubgroupSize: {
674  auto ptrType =
675  spirv::PointerType::get(integerType, spirv::StorageClass::Input);
676  std::string name = getBuiltinVarName(builtin);
677  newVarOp =
678  builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
679  break;
680  }
681  default:
682  emitError(loc, "unimplemented builtin variable generation for ")
683  << stringifyBuiltIn(builtin);
684  }
685  return newVarOp;
686 }
687 
689  spirv::BuiltIn builtin,
690  Type integerType,
691  OpBuilder &builder) {
693  if (!parent) {
694  op->emitError("expected operation to be within a module-like op");
695  return nullptr;
696  }
697 
698  spirv::GlobalVariableOp varOp =
699  getOrInsertBuiltinVariable(*parent->getRegion(0).begin(), op->getLoc(),
700  builtin, integerType, builder);
701  Value ptr = builder.create<spirv::AddressOfOp>(op->getLoc(), varOp);
702  return builder.create<spirv::LoadOp>(op->getLoc(), ptr);
703 }
704 
705 //===----------------------------------------------------------------------===//
706 // Push constant storage
707 //===----------------------------------------------------------------------===//
708 
709 /// Returns the pointer type for the push constant storage containing
710 /// `elementCount` 32-bit integer values.
711 static spirv::PointerType getPushConstantStorageType(unsigned elementCount,
712  Builder &builder,
713  Type indexType) {
714  auto arrayType = spirv::ArrayType::get(indexType, elementCount,
715  /*stride=*/4);
716  auto structType = spirv::StructType::get({arrayType}, /*offsetInfo=*/0);
717  return spirv::PointerType::get(structType, spirv::StorageClass::PushConstant);
718 }
719 
720 /// Returns the push constant variable containing `elementCount` 32-bit integer
721 /// values in `body`. Returns null op if such an op does not exist.
722 static spirv::GlobalVariableOp getPushConstantVariable(Block &body,
723  unsigned elementCount) {
724  for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
725  auto ptrType = varOp.getType().dyn_cast<spirv::PointerType>();
726  if (!ptrType)
727  continue;
728 
729  // Note that Vulkan requires "There must be no more than one push constant
730  // block statically used per shader entry point." So we should always reuse
731  // the existing one.
732  if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) {
733  auto numElements = ptrType.getPointeeType()
735  .getElementType(0)
737  .getNumElements();
738  if (numElements == elementCount)
739  return varOp;
740  }
741  }
742  return nullptr;
743 }
744 
745 /// Gets or inserts a global variable for push constant storage containing
746 /// `elementCount` 32-bit integer values in `block`.
747 static spirv::GlobalVariableOp
749  unsigned elementCount, OpBuilder &b,
750  Type indexType) {
751  if (auto varOp = getPushConstantVariable(block, elementCount))
752  return varOp;
753 
754  auto builder = OpBuilder::atBlockBegin(&block, b.getListener());
755  auto type = getPushConstantStorageType(elementCount, builder, indexType);
756  const char *name = "__push_constant_var__";
757  return builder.create<spirv::GlobalVariableOp>(loc, type, name,
758  /*initializer=*/nullptr);
759 }
760 
761 Value spirv::getPushConstantValue(Operation *op, unsigned elementCount,
762  unsigned offset, Type integerType,
763  OpBuilder &builder) {
764  Location loc = op->getLoc();
766  if (!parent) {
767  op->emitError("expected operation to be within a module-like op");
768  return nullptr;
769  }
770 
771  spirv::GlobalVariableOp varOp = getOrInsertPushConstantVariable(
772  loc, parent->getRegion(0).front(), elementCount, builder, integerType);
773 
774  Value zeroOp = spirv::ConstantOp::getZero(integerType, loc, builder);
775  Value offsetOp = builder.create<spirv::ConstantOp>(
776  loc, integerType, builder.getI32IntegerAttr(offset));
777  auto addrOp = builder.create<spirv::AddressOfOp>(loc, varOp);
778  auto acOp = builder.create<spirv::AccessChainOp>(
779  loc, addrOp, llvm::ArrayRef({zeroOp, offsetOp}));
780  return builder.create<spirv::LoadOp>(loc, acOp);
781 }
782 
783 //===----------------------------------------------------------------------===//
784 // Index calculation
785 //===----------------------------------------------------------------------===//
786 
788  int64_t offset, Type integerType,
789  Location loc, OpBuilder &builder) {
790  assert(indices.size() == strides.size() &&
791  "must provide indices for all dimensions");
792 
793  // TODO: Consider moving to use affine.apply and patterns converting
794  // affine.apply to standard ops. This needs converting to SPIR-V passes to be
795  // broken down into progressive small steps so we can have intermediate steps
796  // using other dialects. At the moment SPIR-V is the final sink.
797 
798  Value linearizedIndex = builder.create<spirv::ConstantOp>(
799  loc, integerType, IntegerAttr::get(integerType, offset));
800  for (const auto &index : llvm::enumerate(indices)) {
801  Value strideVal = builder.create<spirv::ConstantOp>(
802  loc, integerType,
803  IntegerAttr::get(integerType, strides[index.index()]));
804  Value update = builder.create<spirv::IMulOp>(loc, strideVal, index.value());
805  linearizedIndex =
806  builder.create<spirv::IAddOp>(loc, linearizedIndex, update);
807  }
808  return linearizedIndex;
809 }
810 
812  MemRefType baseType, Value basePtr,
813  ValueRange indices, Location loc,
814  OpBuilder &builder) {
815  // Get base and offset of the MemRefType and verify they are static.
816 
817  int64_t offset;
818  SmallVector<int64_t, 4> strides;
819  if (failed(getStridesAndOffset(baseType, strides, offset)) ||
820  llvm::is_contained(strides, ShapedType::kDynamic) ||
821  ShapedType::isDynamic(offset)) {
822  return nullptr;
823  }
824 
825  auto indexType = typeConverter.getIndexType();
826 
827  SmallVector<Value, 2> linearizedIndices;
828  auto zero = spirv::ConstantOp::getZero(indexType, loc, builder);
829 
830  // Add a '0' at the start to index into the struct.
831  linearizedIndices.push_back(zero);
832 
833  if (baseType.getRank() == 0) {
834  linearizedIndices.push_back(zero);
835  } else {
836  linearizedIndices.push_back(
837  linearizeIndex(indices, strides, offset, indexType, loc, builder));
838  }
839  return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
840 }
841 
843  MemRefType baseType, Value basePtr,
844  ValueRange indices, Location loc,
845  OpBuilder &builder) {
846  // Get base and offset of the MemRefType and verify they are static.
847 
848  int64_t offset;
849  SmallVector<int64_t, 4> strides;
850  if (failed(getStridesAndOffset(baseType, strides, offset)) ||
851  llvm::is_contained(strides, ShapedType::kDynamic) ||
852  ShapedType::isDynamic(offset)) {
853  return nullptr;
854  }
855 
856  auto indexType = typeConverter.getIndexType();
857 
858  SmallVector<Value, 2> linearizedIndices;
859  Value linearIndex;
860  if (baseType.getRank() == 0) {
861  linearIndex = spirv::ConstantOp::getZero(indexType, loc, builder);
862  } else {
863  linearIndex =
864  linearizeIndex(indices, strides, offset, indexType, loc, builder);
865  }
866  Type pointeeType =
867  basePtr.getType().cast<spirv::PointerType>().getPointeeType();
868  if (pointeeType.isa<spirv::ArrayType>()) {
869  linearizedIndices.push_back(linearIndex);
870  return builder.create<spirv::AccessChainOp>(loc, basePtr,
871  linearizedIndices);
872  }
873  return builder.create<spirv::PtrAccessChainOp>(loc, basePtr, linearIndex,
874  linearizedIndices);
875 }
876 
878  MemRefType baseType, Value basePtr,
879  ValueRange indices, Location loc,
880  OpBuilder &builder) {
881 
882  if (typeConverter.allows(spirv::Capability::Kernel)) {
883  return getOpenCLElementPtr(typeConverter, baseType, basePtr, indices, loc,
884  builder);
885  }
886 
887  return getVulkanElementPtr(typeConverter, baseType, basePtr, indices, loc,
888  builder);
889 }
890 
891 //===----------------------------------------------------------------------===//
892 // SPIR-V ConversionTarget
893 //===----------------------------------------------------------------------===//
894 
895 std::unique_ptr<SPIRVConversionTarget>
897  std::unique_ptr<SPIRVConversionTarget> target(
898  // std::make_unique does not work here because the constructor is private.
899  new SPIRVConversionTarget(targetAttr));
900  SPIRVConversionTarget *targetPtr = target.get();
901  target->addDynamicallyLegalDialect<spirv::SPIRVDialect>(
902  // We need to capture the raw pointer here because it is stable:
903  // target will be destroyed once this function is returned.
904  [targetPtr](Operation *op) { return targetPtr->isLegalOp(op); });
905  return target;
906 }
907 
908 SPIRVConversionTarget::SPIRVConversionTarget(spirv::TargetEnvAttr targetAttr)
909  : ConversionTarget(*targetAttr.getContext()), targetEnv(targetAttr) {}
910 
911 bool SPIRVConversionTarget::isLegalOp(Operation *op) {
912  // Make sure this op is available at the given version. Ops not implementing
913  // QueryMinVersionInterface/QueryMaxVersionInterface are available to all
914  // SPIR-V versions.
915  if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
916  std::optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
917  if (minVersion && *minVersion > this->targetEnv.getVersion()) {
918  LLVM_DEBUG(llvm::dbgs()
919  << op->getName() << " illegal: requiring min version "
920  << spirv::stringifyVersion(*minVersion) << "\n");
921  return false;
922  }
923  }
924  if (auto maxVersionIfx = dyn_cast<spirv::QueryMaxVersionInterface>(op)) {
925  std::optional<spirv::Version> maxVersion = maxVersionIfx.getMaxVersion();
926  if (maxVersion && *maxVersion < this->targetEnv.getVersion()) {
927  LLVM_DEBUG(llvm::dbgs()
928  << op->getName() << " illegal: requiring max version "
929  << spirv::stringifyVersion(*maxVersion) << "\n");
930  return false;
931  }
932  }
933 
934  // Make sure this op's required extensions are allowed to use. Ops not
935  // implementing QueryExtensionInterface do not require extensions to be
936  // available.
937  if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
938  if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
939  extensions.getExtensions())))
940  return false;
941 
942  // Make sure this op's required extensions are allowed to use. Ops not
943  // implementing QueryCapabilityInterface do not require capabilities to be
944  // available.
945  if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
946  if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
947  capabilities.getCapabilities())))
948  return false;
949 
950  SmallVector<Type, 4> valueTypes;
951  valueTypes.append(op->operand_type_begin(), op->operand_type_end());
952  valueTypes.append(op->result_type_begin(), op->result_type_end());
953 
954  // Ensure that all types have been converted to SPIRV types.
955  if (llvm::any_of(valueTypes,
956  [](Type t) { return !t.isa<spirv::SPIRVType>(); }))
957  return false;
958 
959  // Special treatment for global variables, whose type requirements are
960  // conveyed by type attributes.
961  if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
962  valueTypes.push_back(globalVar.getType());
963 
964  // Make sure the op's operands/results use types that are allowed by the
965  // target environment.
966  SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
967  SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
968  for (Type valueType : valueTypes) {
969  typeExtensions.clear();
970  valueType.cast<spirv::SPIRVType>().getExtensions(typeExtensions);
971  if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
972  typeExtensions)))
973  return false;
974 
975  typeCapabilities.clear();
976  valueType.cast<spirv::SPIRVType>().getCapabilities(typeCapabilities);
977  if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
978  typeCapabilities)))
979  return false;
980  }
981 
982  return true;
983 }
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static bool needsExplicitLayout(spirv::StorageClass storageClass)
Returns true if the given storageClass needs explicit layout when used in Shader environments.
static spirv::GlobalVariableOp getPushConstantVariable(Block &body, unsigned elementCount)
Returns the push constant variable containing elementCount 32-bit integer values in body.
static Type convertTensorType(const spirv::TargetEnv &targetEnv, const SPIRVConversionOptions &options, TensorType type)
Converts a tensor type to a suitable type under the given targetEnv.
static std::string getBuiltinVarName(spirv::BuiltIn builtin)
Gets name of global variable for a builtin.
static LogicalResult checkCapabilityRequirements(LabelT label, const spirv::TargetEnv &targetEnv, const spirv::SPIRVType::CapabilityArrayRefVector &candidates)
Checks that candidates capability requirements are possible to be satisfied with the given targetEnv.
static std::optional< int64_t > getTypeNumBytes(const SPIRVConversionOptions &options, Type type)
static spirv::GlobalVariableOp getBuiltinVariable(Block &body, spirv::BuiltIn builtin)
static ShapedType convertIndexElementType(ShapedType type, const SPIRVConversionOptions &options)
Returns a type with the same shape but with any index element type converted to the matching integer ...
static spirv::GlobalVariableOp getOrInsertPushConstantVariable(Location loc, Block &block, unsigned elementCount, OpBuilder &b, Type indexType)
Gets or inserts a global variable for push constant storage containing elementCount 32-bit integer va...
static Type convertComplexType(const spirv::TargetEnv &targetEnv, const SPIRVConversionOptions &options, ComplexType type, std::optional< spirv::StorageClass > storageClass={})
static LogicalResult checkExtensionRequirements(LabelT label, const spirv::TargetEnv &targetEnv, const spirv::SPIRVType::ExtensionArrayRefVector &candidates)
Checks that candidates extension requirements are possible to be satisfied with the given targetEnv.
static spirv::ScalarType getIndexType(MLIRContext *ctx, const SPIRVConversionOptions &options)
static Type convertScalarType(const spirv::TargetEnv &targetEnv, const SPIRVConversionOptions &options, spirv::ScalarType type, std::optional< spirv::StorageClass > storageClass={})
Converts a scalar type to a suitable type under the given targetEnv.
static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv, const SPIRVConversionOptions &options, MemRefType type, spirv::StorageClass storageClass)
static Type convertVectorType(const spirv::TargetEnv &targetEnv, const SPIRVConversionOptions &options, VectorType type, std::optional< spirv::StorageClass > storageClass={})
Converts a vector type to a suitable type under the given targetEnv.
static spirv::PointerType wrapInStructAndGetPointer(Type elementType, spirv::StorageClass storageClass)
Wraps the given elementType in a struct and gets the pointer to the struct.
static spirv::PointerType getPushConstantStorageType(unsigned elementCount, Builder &builder, Type indexType)
Returns the pointer type for the push constant storage containing elementCount 32-bit integer values.
static Type convertMemrefType(const spirv::TargetEnv &targetEnv, const SPIRVConversionOptions &options, MemRefType type)
static spirv::GlobalVariableOp getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin, Type integerType, OpBuilder &builder)
Gets or inserts a global variable for a builtin within body block.
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:698
static int64_t getNumElements(ShapedType type)
Definition: TensorOps.cpp:1223
Block represents an ordered list of Operations.
Definition: Block.h:30
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
Definition: Block.h:182
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:202
FloatType getF32Type()
Definition: Builders.cpp:60
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:93
This class implements a pattern rewriter for use with ConversionPatterns.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before) override
PatternRewriter hook for moving blocks out of a region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
FailureOr< Block * > convertRegionTypes(Region *region, TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Convert the types of block arguments within the given region.
This class describes a specific conversion target.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:329
This class helps build Operations.
Definition: Builders.h:202
static OpBuilder atBlockBegin(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still ins...
Definition: Builders.h:234
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:412
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:301
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:432
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:75
operand_type_iterator operand_type_end()
Definition: Operation.h:375
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:207
result_type_iterator result_type_end()
Definition: Operation.h:406
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:218
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:234
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:550
result_type_iterator result_type_begin()
Definition: Operation.h:405
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:457
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:103
operand_type_iterator operand_type_begin()
Definition: Operation.h:374
iterator begin()
Definition: Region.h:55
Block & front()
Definition: Region.h:65
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
Type getIndexType() const
Gets the SPIR-V correspondence for the standard index type.
SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr, const SPIRVConversionOptions &options={})
bool allows(spirv::Capability capability)
Checks if the SPIR-V capability inquired is supported.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:59
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:80
Type getElementType() const
Returns the element type of this tensor type.
This class provides all of the information necessary to convert a type signature.
void addConversion(FnT &&callback)
Register a conversion function.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
U cast() const
Definition: Types.h:321
U dyn_cast_or_null() const
Definition: Types.h:316
U dyn_cast() const
Definition: Types.h:311
bool isa() const
Definition: Types.h:301
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:112
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:370
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
Type getType() const
Return the type of this value.
Definition: Value.h:122
static ArrayType get(Type elementType, unsigned elementCount)
Definition: SPIRVTypes.cpp:51
static bool isValid(VectorType)
Returns true if the given vector type is valid for the SPIR-V dialect.
Definition: SPIRVTypes.cpp:100
Type getPointeeType() const
Definition: SPIRVTypes.cpp:482
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:478
static RuntimeArrayType get(Type elementType)
Definition: SPIRVTypes.cpp:535
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:586
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:617
SPIR-V struct type.
Definition: SPIRVTypes.h:282
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
An attribute that specifies the target version, allowed extensions and capabilities,...
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
Definition: TargetAndABI.h:29
Version getVersion() const
bool allows(Capability) const
Returns true if the given capability is allowed.
TargetEnvAttr getAttr() const
Definition: TargetAndABI.h:62
MLIRContext * getContext() const
Returns the MLIRContext.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:223
Value getBuiltinVariableValue(Operation *op, BuiltIn builtin, Type integerType, OpBuilder &builder)
Returns the value for the given builtin variable.
Value getPushConstantValue(Operation *op, unsigned elementCount, unsigned offset, Type integerType, OpBuilder &builder)
Gets the value at the given offset of the push constant storage with a total of elementCount integerT...
Value getElementPtr(SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Value getOpenCLElementPtr(SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Value linearizeIndex(ValueRange indices, ArrayRef< int64_t > strides, int64_t offset, Type integerType, Location loc, OpBuilder &builder)
Generates IR to perform index linearization with the given indices and their corresponding strides,...
Value getVulkanElementPtr(SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
void populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating the builtin func op to the SPIR-V diale...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26