MLIR  16.0.0git
SPIRVConversion.cpp
Go to the documentation of this file.
1 //===- SPIRVConversion.cpp - SPIR-V Conversion Utilities ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities used to lower to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
20 #include "llvm/ADT/Sequence.h"
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/Support/Debug.h"
23 
24 #include <functional>
25 
26 #define DEBUG_TYPE "mlir-spirv-conversion"
27 
28 using namespace mlir;
29 
30 //===----------------------------------------------------------------------===//
31 // Utility functions
32 //===----------------------------------------------------------------------===//
33 
34 /// Checks that `candidates` extension requirements are possible to be satisfied
35 /// with the given `targetEnv`.
36 ///
37 /// `candidates` is a vector of vector for extension requirements following
38 /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
39 /// convention.
40 template <typename LabelT>
42  LabelT label, const spirv::TargetEnv &targetEnv,
44  for (const auto &ors : candidates) {
45  if (targetEnv.allows(ors))
46  continue;
47 
48  LLVM_DEBUG({
49  SmallVector<StringRef> extStrings;
50  for (spirv::Extension ext : ors)
51  extStrings.push_back(spirv::stringifyExtension(ext));
52 
53  llvm::dbgs() << label << " illegal: requires at least one extension in ["
54  << llvm::join(extStrings, ", ")
55  << "] but none allowed in target environment\n";
56  });
57  return failure();
58  }
59  return success();
60 }
61 
/// Checks that `candidates` capability requirements are possible to be satisfied
/// with the given `targetEnv`.
64 ///
65 /// `candidates` is a vector of vector for capability requirements following
66 /// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
67 /// convention.
68 template <typename LabelT>
70  LabelT label, const spirv::TargetEnv &targetEnv,
72  for (const auto &ors : candidates) {
73  if (targetEnv.allows(ors))
74  continue;
75 
76  LLVM_DEBUG({
77  SmallVector<StringRef> capStrings;
78  for (spirv::Capability cap : ors)
79  capStrings.push_back(spirv::stringifyCapability(cap));
80 
81  llvm::dbgs() << label << " illegal: requires at least one capability in ["
82  << llvm::join(capStrings, ", ")
83  << "] but none allowed in target environment\n";
84  });
85  return failure();
86  }
87  return success();
88 }
89 
90 /// Returns true if the given `storageClass` needs explicit layout when used in
91 /// Shader environments.
92 static bool needsExplicitLayout(spirv::StorageClass storageClass) {
93  switch (storageClass) {
94  case spirv::StorageClass::PhysicalStorageBuffer:
95  case spirv::StorageClass::PushConstant:
96  case spirv::StorageClass::StorageBuffer:
97  case spirv::StorageClass::Uniform:
98  return true;
99  default:
100  return false;
101  }
102 }
103 
104 /// Wraps the given `elementType` in a struct and gets the pointer to the
105 /// struct. This is used to satisfy Vulkan interface requirements.
106 static spirv::PointerType
107 wrapInStructAndGetPointer(Type elementType, spirv::StorageClass storageClass) {
108  auto structType = needsExplicitLayout(storageClass)
109  ? spirv::StructType::get(elementType, /*offsetInfo=*/0)
110  : spirv::StructType::get(elementType);
111  return spirv::PointerType::get(structType, storageClass);
112 }
113 
114 //===----------------------------------------------------------------------===//
115 // Type Conversion
116 //===----------------------------------------------------------------------===//
117 
119  return IntegerType::get(getContext(), options.use64bitIndex ? 64 : 32);
120 }
121 
122 MLIRContext *SPIRVTypeConverter::getContext() const {
123  return targetEnv.getAttr().getContext();
124 }
125 
126 bool SPIRVTypeConverter::allows(spirv::Capability capability) {
127  return targetEnv.allows(capability);
128 }
129 
130 // TODO: This is a utility function that should probably be exposed by the
131 // SPIR-V dialect. Keeping it local till the use case arises.
133  Type type) {
134  if (type.isa<spirv::ScalarType>()) {
135  auto bitWidth = type.getIntOrFloatBitWidth();
136  // According to the SPIR-V spec:
137  // "There is no physical size or bit pattern defined for values with boolean
138  // type. If they are stored (in conjunction with OpVariable), they can only
139  // be used with logical addressing operations, not physical, and only with
140  // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
141  // Private, Function, Input, and Output."
142  if (bitWidth == 1)
143  return llvm::None;
144  return bitWidth / 8;
145  }
146 
147  if (auto vecType = type.dyn_cast<VectorType>()) {
148  auto elementSize = getTypeNumBytes(options, vecType.getElementType());
149  if (!elementSize)
150  return llvm::None;
151  return vecType.getNumElements() * *elementSize;
152  }
153 
154  if (auto memRefType = type.dyn_cast<MemRefType>()) {
155  // TODO: Layout should also be controlled by the ABI attributes. For now
156  // using the layout from MemRef.
157  int64_t offset;
158  SmallVector<int64_t, 4> strides;
159  if (!memRefType.hasStaticShape() ||
160  failed(getStridesAndOffset(memRefType, strides, offset)))
161  return llvm::None;
162 
163  // To get the size of the memref object in memory, the total size is the
164  // max(stride * dimension-size) computed for all dimensions times the size
165  // of the element.
166  auto elementSize = getTypeNumBytes(options, memRefType.getElementType());
167  if (!elementSize)
168  return llvm::None;
169 
170  if (memRefType.getRank() == 0)
171  return elementSize;
172 
173  auto dims = memRefType.getShape();
174  if (llvm::is_contained(dims, ShapedType::kDynamic) ||
175  ShapedType::isDynamic(offset) ||
176  llvm::is_contained(strides, ShapedType::kDynamic))
177  return llvm::None;
178 
179  int64_t memrefSize = -1;
180  for (const auto &shape : enumerate(dims))
181  memrefSize = std::max(memrefSize, shape.value() * strides[shape.index()]);
182 
183  return (offset + memrefSize) * *elementSize;
184  }
185 
186  if (auto tensorType = type.dyn_cast<TensorType>()) {
187  if (!tensorType.hasStaticShape())
188  return llvm::None;
189 
190  auto elementSize = getTypeNumBytes(options, tensorType.getElementType());
191  if (!elementSize)
192  return llvm::None;
193 
194  int64_t size = *elementSize;
195  for (auto shape : tensorType.getShape())
196  size *= shape;
197 
198  return size;
199  }
200 
201  // TODO: Add size computation for other types.
202  return llvm::None;
203 }
204 
205 /// Converts a scalar `type` to a suitable type under the given `targetEnv`.
206 static Type convertScalarType(const spirv::TargetEnv &targetEnv,
208  spirv::ScalarType type,
209  Optional<spirv::StorageClass> storageClass = {}) {
210  // Get extension and capability requirements for the given type.
213  type.getExtensions(extensions, storageClass);
214  type.getCapabilities(capabilities, storageClass);
215 
216  // If all requirements are met, then we can accept this type as-is.
217  if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
218  succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
219  return type;
220 
221  // Otherwise we need to adjust the type, which really means adjusting the
222  // bitwidth given this is a scalar type.
223  if (!options.emulateLT32BitScalarTypes)
224  return nullptr;
225 
226  // We only emulate narrower scalar types here and do not truncate results.
227  if (type.getIntOrFloatBitWidth() > 32) {
228  LLVM_DEBUG(llvm::dbgs()
229  << type
230  << " not converted to 32-bit for SPIR-V to avoid truncation\n");
231  return nullptr;
232  }
233 
234  if (auto floatType = type.dyn_cast<FloatType>()) {
235  LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
236  return Builder(targetEnv.getContext()).getF32Type();
237  }
238 
239  auto intType = type.cast<IntegerType>();
240  LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
241  return IntegerType::get(targetEnv.getContext(), /*width=*/32,
242  intType.getSignedness());
243 }
244 
245 /// Converts a vector `type` to a suitable type under the given `targetEnv`.
246 static Type convertVectorType(const spirv::TargetEnv &targetEnv,
248  VectorType type,
249  Optional<spirv::StorageClass> storageClass = {}) {
250  auto scalarType = type.getElementType().cast<spirv::ScalarType>();
251  if (type.getRank() <= 1 && type.getNumElements() == 1)
252  return convertScalarType(targetEnv, options, scalarType, storageClass);
253 
254  if (!spirv::CompositeType::isValid(type)) {
255  // TODO: Vector types with more than four elements can be translated into
256  // array types.
257  LLVM_DEBUG(llvm::dbgs() << type << " illegal: > 4-element unimplemented\n");
258  return nullptr;
259  }
260 
261  // Get extension and capability requirements for the given type.
264  type.cast<spirv::CompositeType>().getExtensions(extensions, storageClass);
265  type.cast<spirv::CompositeType>().getCapabilities(capabilities, storageClass);
266 
267  // If all requirements are met, then we can accept this type as-is.
268  if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
269  succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
270  return type;
271 
272  auto elementType =
273  convertScalarType(targetEnv, options, scalarType, storageClass);
274  if (elementType)
275  return VectorType::get(type.getShape(), elementType);
276  return nullptr;
277 }
278 
279 /// Converts a tensor `type` to a suitable type under the given `targetEnv`.
280 ///
281 /// Note that this is mainly for lowering constant tensors. In SPIR-V one can
/// create composite constants with OpConstantComposite to embed relatively large
283 /// constant values and use OpCompositeExtract and OpCompositeInsert to
284 /// manipulate, like what we do for vectors.
285 static Type convertTensorType(const spirv::TargetEnv &targetEnv,
287  TensorType type) {
288  // TODO: Handle dynamic shapes.
289  if (!type.hasStaticShape()) {
290  LLVM_DEBUG(llvm::dbgs()
291  << type << " illegal: dynamic shape unimplemented\n");
292  return nullptr;
293  }
294 
295  auto scalarType = type.getElementType().dyn_cast<spirv::ScalarType>();
296  if (!scalarType) {
297  LLVM_DEBUG(llvm::dbgs()
298  << type << " illegal: cannot convert non-scalar element type\n");
299  return nullptr;
300  }
301 
302  Optional<int64_t> scalarSize = getTypeNumBytes(options, scalarType);
303  Optional<int64_t> tensorSize = getTypeNumBytes(options, type);
304  if (!scalarSize || !tensorSize) {
305  LLVM_DEBUG(llvm::dbgs()
306  << type << " illegal: cannot deduce element count\n");
307  return nullptr;
308  }
309 
310  auto arrayElemCount = *tensorSize / *scalarSize;
311  auto arrayElemType = convertScalarType(targetEnv, options, scalarType);
312  if (!arrayElemType)
313  return nullptr;
314  Optional<int64_t> arrayElemSize = getTypeNumBytes(options, arrayElemType);
315  if (!arrayElemSize) {
316  LLVM_DEBUG(llvm::dbgs()
317  << type << " illegal: cannot deduce converted element size\n");
318  return nullptr;
319  }
320 
321  return spirv::ArrayType::get(arrayElemType, arrayElemCount);
322 }
323 
326  MemRefType type,
327  spirv::StorageClass storageClass) {
328  unsigned numBoolBits = options.boolNumBits;
329  if (numBoolBits != 8) {
330  LLVM_DEBUG(llvm::dbgs()
331  << "using non-8-bit storage for bool types unimplemented");
332  return nullptr;
333  }
334  auto elementType = IntegerType::get(type.getContext(), numBoolBits)
335  .dyn_cast<spirv::ScalarType>();
336  if (!elementType)
337  return nullptr;
338  Type arrayElemType =
339  convertScalarType(targetEnv, options, elementType, storageClass);
340  if (!arrayElemType)
341  return nullptr;
342  Optional<int64_t> arrayElemSize = getTypeNumBytes(options, arrayElemType);
343  if (!arrayElemSize) {
344  LLVM_DEBUG(llvm::dbgs()
345  << type << " illegal: cannot deduce converted element size\n");
346  return nullptr;
347  }
348 
349 
350  if (!type.hasStaticShape()) {
351  // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
352  // to the element.
353  if (targetEnv.allows(spirv::Capability::Kernel))
354  return spirv::PointerType::get(arrayElemType, storageClass);
355  int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
356  auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
357  // For Vulkan we need extra wrapping struct and array to satisfy interface
358  // needs.
359  return wrapInStructAndGetPointer(arrayType, storageClass);
360  }
361 
362  int64_t memrefSize = (type.getNumElements() * numBoolBits + 7) / 8;
363  auto arrayElemCount = llvm::divideCeil(memrefSize, *arrayElemSize);
364  int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
365  auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
366  if (targetEnv.allows(spirv::Capability::Kernel))
367  return spirv::PointerType::get(arrayType, storageClass);
368  return wrapInStructAndGetPointer(arrayType, storageClass);
369 }
370 
371 static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
373  MemRefType type) {
374  auto attr = type.getMemorySpace().dyn_cast_or_null<spirv::StorageClassAttr>();
375  if (!attr) {
376  LLVM_DEBUG(
377  llvm::dbgs()
378  << type
379  << " illegal: expected memory space to be a SPIR-V storage class "
380  "attribute; please use MemorySpaceToStorageClassConverter to map "
381  "numeric memory spaces beforehand\n");
382  return nullptr;
383  }
384  spirv::StorageClass storageClass = attr.getValue();
385 
386  if (type.getElementType().isa<IntegerType>() &&
387  type.getElementTypeBitWidth() == 1) {
388  return convertBoolMemrefType(targetEnv, options, type, storageClass);
389  }
390 
391  Type arrayElemType;
392  Type elementType = type.getElementType();
393  if (auto vecType = elementType.dyn_cast<VectorType>()) {
394  arrayElemType =
395  convertVectorType(targetEnv, options, vecType, storageClass);
396  } else if (auto scalarType = elementType.dyn_cast<spirv::ScalarType>()) {
397  arrayElemType =
398  convertScalarType(targetEnv, options, scalarType, storageClass);
399  } else {
400  LLVM_DEBUG(
401  llvm::dbgs()
402  << type
403  << " unhandled: can only convert scalar or vector element type\n");
404  return nullptr;
405  }
406  if (!arrayElemType)
407  return nullptr;
408 
409  Optional<int64_t> arrayElemSize = getTypeNumBytes(options, arrayElemType);
410  if (!arrayElemSize) {
411  LLVM_DEBUG(llvm::dbgs()
412  << type << " illegal: cannot deduce converted element size\n");
413  return nullptr;
414  }
415 
416 
417  if (!type.hasStaticShape()) {
418  // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
419  // to the element.
420  if (targetEnv.allows(spirv::Capability::Kernel))
421  return spirv::PointerType::get(arrayElemType, storageClass);
422  int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
423  auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
424  // For Vulkan we need extra wrapping struct and array to satisfy interface
425  // needs.
426  return wrapInStructAndGetPointer(arrayType, storageClass);
427  }
428 
429  Optional<int64_t> memrefSize = getTypeNumBytes(options, type);
430  if (!memrefSize) {
431  LLVM_DEBUG(llvm::dbgs()
432  << type << " illegal: cannot deduce element count\n");
433  return nullptr;
434  }
435 
436  auto arrayElemCount = llvm::divideCeil(*memrefSize, *arrayElemSize);
437  int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
438  auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
439  if (targetEnv.allows(spirv::Capability::Kernel))
440  return spirv::PointerType::get(arrayType, storageClass);
441  return wrapInStructAndGetPointer(arrayType, storageClass);
442 }
443 
446  : targetEnv(targetAttr), options(options) {
447  // Add conversions. The order matters here: later ones will be tried earlier.
448 
449  // Allow all SPIR-V dialect specific types. This assumes all builtin types
450  // adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType)
451  // were tried before.
452  //
453  // TODO: this assumes that the SPIR-V types are valid to use in
454  // the given target environment, which should be the case if the whole
455  // pipeline is driven by the same target environment. Still, we probably still
456  // want to validate and convert to be safe.
457  addConversion([](spirv::SPIRVType type) { return type; });
458 
459  addConversion([this](IndexType /*indexType*/) { return getIndexType(); });
460 
461  addConversion([this](IntegerType intType) -> Optional<Type> {
462  if (auto scalarType = intType.dyn_cast<spirv::ScalarType>())
463  return convertScalarType(this->targetEnv, this->options, scalarType);
464  return Type();
465  });
466 
467  addConversion([this](FloatType floatType) -> Optional<Type> {
468  if (auto scalarType = floatType.dyn_cast<spirv::ScalarType>())
469  return convertScalarType(this->targetEnv, this->options, scalarType);
470  return Type();
471  });
472 
473  addConversion([this](VectorType vectorType) {
474  return convertVectorType(this->targetEnv, this->options, vectorType);
475  });
476 
477  addConversion([this](TensorType tensorType) {
478  return convertTensorType(this->targetEnv, this->options, tensorType);
479  });
480 
481  addConversion([this](MemRefType memRefType) {
482  return convertMemrefType(this->targetEnv, this->options, memRefType);
483  });
484 }
485 
486 //===----------------------------------------------------------------------===//
487 // func::FuncOp Conversion Patterns
488 //===----------------------------------------------------------------------===//
489 
490 namespace {
491 /// A pattern for rewriting function signature to convert arguments of functions
492 /// to be of valid SPIR-V types.
493 class FuncOpConversion final : public OpConversionPattern<func::FuncOp> {
494 public:
496 
498  matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
499  ConversionPatternRewriter &rewriter) const override;
500 };
501 } // namespace
502 
504 FuncOpConversion::matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
505  ConversionPatternRewriter &rewriter) const {
506  auto fnType = funcOp.getFunctionType();
507  if (fnType.getNumResults() > 1)
508  return failure();
509 
510  TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
511  for (const auto &argType : enumerate(fnType.getInputs())) {
512  auto convertedType = getTypeConverter()->convertType(argType.value());
513  if (!convertedType)
514  return failure();
515  signatureConverter.addInputs(argType.index(), convertedType);
516  }
517 
518  Type resultType;
519  if (fnType.getNumResults() == 1) {
520  resultType = getTypeConverter()->convertType(fnType.getResult(0));
521  if (!resultType)
522  return failure();
523  }
524 
525  // Create the converted spirv.func op.
526  auto newFuncOp = rewriter.create<spirv::FuncOp>(
527  funcOp.getLoc(), funcOp.getName(),
528  rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
529  resultType ? TypeRange(resultType)
530  : TypeRange()));
531 
532  // Copy over all attributes other than the function name and type.
533  for (const auto &namedAttr : funcOp->getAttrs()) {
534  if (namedAttr.getName() != FunctionOpInterface::getTypeAttrName() &&
535  namedAttr.getName() != SymbolTable::getSymbolAttrName())
536  newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
537  }
538 
539  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
540  newFuncOp.end());
541  if (failed(rewriter.convertRegionTypes(
542  &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter)))
543  return failure();
544  rewriter.eraseOp(funcOp);
545  return success();
546 }
547 
549  RewritePatternSet &patterns) {
550  patterns.add<FuncOpConversion>(typeConverter, patterns.getContext());
551 }
552 
553 //===----------------------------------------------------------------------===//
554 // Builtin Variables
555 //===----------------------------------------------------------------------===//
556 
557 static spirv::GlobalVariableOp getBuiltinVariable(Block &body,
558  spirv::BuiltIn builtin) {
559  // Look through all global variables in the given `body` block and check if
560  // there is a spirv.GlobalVariable that has the same `builtin` attribute.
561  for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
562  if (auto builtinAttr = varOp->getAttrOfType<StringAttr>(
563  spirv::SPIRVDialect::getAttributeName(
564  spirv::Decoration::BuiltIn))) {
565  auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
566  if (varBuiltIn && *varBuiltIn == builtin) {
567  return varOp;
568  }
569  }
570  }
571  return nullptr;
572 }
573 
574 /// Gets name of global variable for a builtin.
575 static std::string getBuiltinVarName(spirv::BuiltIn builtin) {
576  return std::string("__builtin_var_") + stringifyBuiltIn(builtin).str() + "__";
577 }
578 
579 /// Gets or inserts a global variable for a builtin within `body` block.
580 static spirv::GlobalVariableOp
581 getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
582  Type integerType, OpBuilder &builder) {
583  if (auto varOp = getBuiltinVariable(body, builtin))
584  return varOp;
585 
586  OpBuilder::InsertionGuard guard(builder);
587  builder.setInsertionPointToStart(&body);
588 
589  spirv::GlobalVariableOp newVarOp;
590  switch (builtin) {
591  case spirv::BuiltIn::NumWorkgroups:
592  case spirv::BuiltIn::WorkgroupSize:
593  case spirv::BuiltIn::WorkgroupId:
594  case spirv::BuiltIn::LocalInvocationId:
595  case spirv::BuiltIn::GlobalInvocationId: {
596  auto ptrType = spirv::PointerType::get(VectorType::get({3}, integerType),
597  spirv::StorageClass::Input);
598  std::string name = getBuiltinVarName(builtin);
599  newVarOp =
600  builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
601  break;
602  }
603  case spirv::BuiltIn::SubgroupId:
604  case spirv::BuiltIn::NumSubgroups:
605  case spirv::BuiltIn::SubgroupSize: {
606  auto ptrType =
607  spirv::PointerType::get(integerType, spirv::StorageClass::Input);
608  std::string name = getBuiltinVarName(builtin);
609  newVarOp =
610  builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
611  break;
612  }
613  default:
614  emitError(loc, "unimplemented builtin variable generation for ")
615  << stringifyBuiltIn(builtin);
616  }
617  return newVarOp;
618 }
619 
621  spirv::BuiltIn builtin,
622  Type integerType,
623  OpBuilder &builder) {
625  if (!parent) {
626  op->emitError("expected operation to be within a module-like op");
627  return nullptr;
628  }
629 
630  spirv::GlobalVariableOp varOp =
631  getOrInsertBuiltinVariable(*parent->getRegion(0).begin(), op->getLoc(),
632  builtin, integerType, builder);
633  Value ptr = builder.create<spirv::AddressOfOp>(op->getLoc(), varOp);
634  return builder.create<spirv::LoadOp>(op->getLoc(), ptr);
635 }
636 
637 //===----------------------------------------------------------------------===//
638 // Push constant storage
639 //===----------------------------------------------------------------------===//
640 
641 /// Returns the pointer type for the push constant storage containing
642 /// `elementCount` 32-bit integer values.
643 static spirv::PointerType getPushConstantStorageType(unsigned elementCount,
644  Builder &builder,
645  Type indexType) {
646  auto arrayType = spirv::ArrayType::get(indexType, elementCount,
647  /*stride=*/4);
648  auto structType = spirv::StructType::get({arrayType}, /*offsetInfo=*/0);
649  return spirv::PointerType::get(structType, spirv::StorageClass::PushConstant);
650 }
651 
652 /// Returns the push constant varible containing `elementCount` 32-bit integer
653 /// values in `body`. Returns null op if such an op does not exit.
654 static spirv::GlobalVariableOp getPushConstantVariable(Block &body,
655  unsigned elementCount) {
656  for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
657  auto ptrType = varOp.getType().dyn_cast<spirv::PointerType>();
658  if (!ptrType)
659  continue;
660 
661  // Note that Vulkan requires "There must be no more than one push constant
662  // block statically used per shader entry point." So we should always reuse
663  // the existing one.
664  if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) {
665  auto numElements = ptrType.getPointeeType()
667  .getElementType(0)
669  .getNumElements();
670  if (numElements == elementCount)
671  return varOp;
672  }
673  }
674  return nullptr;
675 }
676 
677 /// Gets or inserts a global variable for push constant storage containing
678 /// `elementCount` 32-bit integer values in `block`.
679 static spirv::GlobalVariableOp
681  unsigned elementCount, OpBuilder &b,
682  Type indexType) {
683  if (auto varOp = getPushConstantVariable(block, elementCount))
684  return varOp;
685 
686  auto builder = OpBuilder::atBlockBegin(&block, b.getListener());
687  auto type = getPushConstantStorageType(elementCount, builder, indexType);
688  const char *name = "__push_constant_var__";
689  return builder.create<spirv::GlobalVariableOp>(loc, type, name,
690  /*initializer=*/nullptr);
691 }
692 
693 Value spirv::getPushConstantValue(Operation *op, unsigned elementCount,
694  unsigned offset, Type integerType,
695  OpBuilder &builder) {
696  Location loc = op->getLoc();
698  if (!parent) {
699  op->emitError("expected operation to be within a module-like op");
700  return nullptr;
701  }
702 
703  spirv::GlobalVariableOp varOp = getOrInsertPushConstantVariable(
704  loc, parent->getRegion(0).front(), elementCount, builder, integerType);
705 
706  Value zeroOp = spirv::ConstantOp::getZero(integerType, loc, builder);
707  Value offsetOp = builder.create<spirv::ConstantOp>(
708  loc, integerType, builder.getI32IntegerAttr(offset));
709  auto addrOp = builder.create<spirv::AddressOfOp>(loc, varOp);
710  auto acOp = builder.create<spirv::AccessChainOp>(
711  loc, addrOp, llvm::makeArrayRef({zeroOp, offsetOp}));
712  return builder.create<spirv::LoadOp>(loc, acOp);
713 }
714 
715 //===----------------------------------------------------------------------===//
716 // Index calculation
717 //===----------------------------------------------------------------------===//
718 
720  int64_t offset, Type integerType,
721  Location loc, OpBuilder &builder) {
722  assert(indices.size() == strides.size() &&
723  "must provide indices for all dimensions");
724 
725  // TODO: Consider moving to use affine.apply and patterns converting
726  // affine.apply to standard ops. This needs converting to SPIR-V passes to be
727  // broken down into progressive small steps so we can have intermediate steps
728  // using other dialects. At the moment SPIR-V is the final sink.
729 
730  Value linearizedIndex = builder.create<spirv::ConstantOp>(
731  loc, integerType, IntegerAttr::get(integerType, offset));
732  for (const auto &index : llvm::enumerate(indices)) {
733  Value strideVal = builder.create<spirv::ConstantOp>(
734  loc, integerType,
735  IntegerAttr::get(integerType, strides[index.index()]));
736  Value update = builder.create<spirv::IMulOp>(loc, strideVal, index.value());
737  linearizedIndex =
738  builder.create<spirv::IAddOp>(loc, linearizedIndex, update);
739  }
740  return linearizedIndex;
741 }
742 
744  MemRefType baseType, Value basePtr,
745  ValueRange indices, Location loc,
746  OpBuilder &builder) {
747  // Get base and offset of the MemRefType and verify they are static.
748 
749  int64_t offset;
750  SmallVector<int64_t, 4> strides;
751  if (failed(getStridesAndOffset(baseType, strides, offset)) ||
752  llvm::is_contained(strides, ShapedType::kDynamic) ||
753  ShapedType::isDynamic(offset)) {
754  return nullptr;
755  }
756 
757  auto indexType = typeConverter.getIndexType();
758 
759  SmallVector<Value, 2> linearizedIndices;
760  auto zero = spirv::ConstantOp::getZero(indexType, loc, builder);
761 
762  // Add a '0' at the start to index into the struct.
763  linearizedIndices.push_back(zero);
764 
765  if (baseType.getRank() == 0) {
766  linearizedIndices.push_back(zero);
767  } else {
768  linearizedIndices.push_back(
769  linearizeIndex(indices, strides, offset, indexType, loc, builder));
770  }
771  return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
772 }
773 
775  MemRefType baseType, Value basePtr,
776  ValueRange indices, Location loc,
777  OpBuilder &builder) {
778  // Get base and offset of the MemRefType and verify they are static.
779 
780  int64_t offset;
781  SmallVector<int64_t, 4> strides;
782  if (failed(getStridesAndOffset(baseType, strides, offset)) ||
783  llvm::is_contained(strides, ShapedType::kDynamic) ||
784  ShapedType::isDynamic(offset)) {
785  return nullptr;
786  }
787 
788  auto indexType = typeConverter.getIndexType();
789 
790  SmallVector<Value, 2> linearizedIndices;
791  Value linearIndex;
792  if (baseType.getRank() == 0) {
793  linearIndex = spirv::ConstantOp::getZero(indexType, loc, builder);
794  } else {
795  linearIndex =
796  linearizeIndex(indices, strides, offset, indexType, loc, builder);
797  }
798  Type pointeeType =
799  basePtr.getType().cast<spirv::PointerType>().getPointeeType();
800  if (pointeeType.isa<spirv::ArrayType>()) {
801  linearizedIndices.push_back(linearIndex);
802  return builder.create<spirv::AccessChainOp>(loc, basePtr,
803  linearizedIndices);
804  }
805  return builder.create<spirv::PtrAccessChainOp>(loc, basePtr, linearIndex,
806  linearizedIndices);
807 }
808 
810  MemRefType baseType, Value basePtr,
811  ValueRange indices, Location loc,
812  OpBuilder &builder) {
813 
814  if (typeConverter.allows(spirv::Capability::Kernel)) {
815  return getOpenCLElementPtr(typeConverter, baseType, basePtr, indices, loc,
816  builder);
817  }
818 
819  return getVulkanElementPtr(typeConverter, baseType, basePtr, indices, loc,
820  builder);
821 }
822 
823 //===----------------------------------------------------------------------===//
824 // SPIR-V ConversionTarget
825 //===----------------------------------------------------------------------===//
826 
827 std::unique_ptr<SPIRVConversionTarget>
829  std::unique_ptr<SPIRVConversionTarget> target(
830  // std::make_unique does not work here because the constructor is private.
831  new SPIRVConversionTarget(targetAttr));
832  SPIRVConversionTarget *targetPtr = target.get();
833  target->addDynamicallyLegalDialect<spirv::SPIRVDialect>(
834  // We need to capture the raw pointer here because it is stable:
835  // target will be destroyed once this function is returned.
836  [targetPtr](Operation *op) { return targetPtr->isLegalOp(op); });
837  return target;
838 }
839 
840 SPIRVConversionTarget::SPIRVConversionTarget(spirv::TargetEnvAttr targetAttr)
841  : ConversionTarget(*targetAttr.getContext()), targetEnv(targetAttr) {}
842 
843 bool SPIRVConversionTarget::isLegalOp(Operation *op) {
844  // Make sure this op is available at the given version. Ops not implementing
845  // QueryMinVersionInterface/QueryMaxVersionInterface are available to all
846  // SPIR-V versions.
847  if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
848  Optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
849  if (minVersion && *minVersion > this->targetEnv.getVersion()) {
850  LLVM_DEBUG(llvm::dbgs()
851  << op->getName() << " illegal: requiring min version "
852  << spirv::stringifyVersion(*minVersion) << "\n");
853  return false;
854  }
855  }
856  if (auto maxVersionIfx = dyn_cast<spirv::QueryMaxVersionInterface>(op)) {
857  Optional<spirv::Version> maxVersion = maxVersionIfx.getMaxVersion();
858  if (maxVersion && *maxVersion < this->targetEnv.getVersion()) {
859  LLVM_DEBUG(llvm::dbgs()
860  << op->getName() << " illegal: requiring max version "
861  << spirv::stringifyVersion(*maxVersion) << "\n");
862  return false;
863  }
864  }
865 
866  // Make sure this op's required extensions are allowed to use. Ops not
867  // implementing QueryExtensionInterface do not require extensions to be
868  // available.
869  if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
870  if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
871  extensions.getExtensions())))
872  return false;
873 
874  // Make sure this op's required extensions are allowed to use. Ops not
875  // implementing QueryCapabilityInterface do not require capabilities to be
876  // available.
877  if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
878  if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
879  capabilities.getCapabilities())))
880  return false;
881 
882  SmallVector<Type, 4> valueTypes;
883  valueTypes.append(op->operand_type_begin(), op->operand_type_end());
884  valueTypes.append(op->result_type_begin(), op->result_type_end());
885 
886  // Ensure that all types have been converted to SPIRV types.
887  if (llvm::any_of(valueTypes,
888  [](Type t) { return !t.isa<spirv::SPIRVType>(); }))
889  return false;
890 
891  // Special treatment for global variables, whose type requirements are
892  // conveyed by type attributes.
893  if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
894  valueTypes.push_back(globalVar.getType());
895 
896  // Make sure the op's operands/results use types that are allowed by the
897  // target environment.
898  SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
899  SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
900  for (Type valueType : valueTypes) {
901  typeExtensions.clear();
902  valueType.cast<spirv::SPIRVType>().getExtensions(typeExtensions);
903  if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
904  typeExtensions)))
905  return false;
906 
907  typeCapabilities.clear();
908  valueType.cast<spirv::SPIRVType>().getCapabilities(typeCapabilities);
909  if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
910  typeCapabilities)))
911  return false;
912  }
913 
914  return true;
915 }
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static bool needsExplicitLayout(spirv::StorageClass storageClass)
Returns true if the given storageClass needs explicit layout when used in Shader environments.
static Type convertVectorType(const spirv::TargetEnv &targetEnv, const SPIRVConversionOptions &options, VectorType type, Optional< spirv::StorageClass > storageClass={})
Converts a vector type to a suitable type under the given targetEnv.
static spirv::GlobalVariableOp getPushConstantVariable(Block &body, unsigned elementCount)
Returns the push constant variable containing elementCount 32-bit integer values in body.
static Type convertTensorType(const spirv::TargetEnv &targetEnv, const SPIRVConversionOptions &options, TensorType type)
Converts a tensor type to a suitable type under the given targetEnv.
static std::string getBuiltinVarName(spirv::BuiltIn builtin)
Gets name of global variable for a builtin.
static LogicalResult checkCapabilityRequirements(LabelT label, const spirv::TargetEnv &targetEnv, const spirv::SPIRVType::CapabilityArrayRefVector &candidates)
Checks that candidates capability requirements are possible to be satisfied with the given targetEnv...
static spirv::GlobalVariableOp getBuiltinVariable(Block &body, spirv::BuiltIn builtin)
static spirv::GlobalVariableOp getOrInsertPushConstantVariable(Location loc, Block &block, unsigned elementCount, OpBuilder &b, Type indexType)
Gets or inserts a global variable for push constant storage containing elementCount 32-bit integer va...
static LogicalResult checkExtensionRequirements(LabelT label, const spirv::TargetEnv &targetEnv, const spirv::SPIRVType::ExtensionArrayRefVector &candidates)
Checks that candidates extension requirements are possible to be satisfied with the given targetEnv.
static Optional< int64_t > getTypeNumBytes(const SPIRVConversionOptions &options, Type type)
static Type convertScalarType(const spirv::TargetEnv &targetEnv, const SPIRVConversionOptions &options, spirv::ScalarType type, Optional< spirv::StorageClass > storageClass={})
Converts a scalar type to a suitable type under the given targetEnv.
static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv, const SPIRVConversionOptions &options, MemRefType type, spirv::StorageClass storageClass)
static spirv::PointerType wrapInStructAndGetPointer(Type elementType, spirv::StorageClass storageClass)
Wraps the given elementType in a struct and gets the pointer to the struct.
static spirv::PointerType getPushConstantStorageType(unsigned elementCount, Builder &builder, Type indexType)
Returns the pointer type for the push constant storage containing elementCount 32-bit integer values.
static Type convertMemrefType(const spirv::TargetEnv &targetEnv, const SPIRVConversionOptions &options, MemRefType type)
static spirv::GlobalVariableOp getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin, Type integerType, OpBuilder &builder)
Gets or inserts a global variable for a builtin within body block.
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:696
static int64_t getNumElements(ShapedType type)
Definition: TensorOps.cpp:1252
Block represents an ordered list of Operations.
Definition: Block.h:30
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
Definition: Block.h:182
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:49
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:190
FloatType getF32Type()
Definition: Builders.cpp:48
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:81
This class implements a pattern rewriter for use with ConversionPatterns.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before) override
PatternRewriter hook for moving blocks out of a region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
FailureOr< Block * > convertRegionTypes(Region *region, TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Convert the types of block arguments within the given region.
This class describes a specific conversion target.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:64
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:56
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:300
This class helps build Operations.
Definition: Builders.h:198
static OpBuilder atBlockBegin(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still ins...
Definition: Builders.h:230
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:383
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:272
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:422
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:31
operand_type_iterator operand_type_end()
Definition: Operation.h:313
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:154
result_type_iterator result_type_end()
Definition: Operation.h:344
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:165
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:225
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:486
result_type_iterator result_type_begin()
Definition: Operation.h:343
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:395
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:50
operand_type_iterator operand_type_begin()
Definition: Operation.h:312
iterator begin()
Definition: Region.h:55
Block & front()
Definition: Region.h:65
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
Type getIndexType() const
Gets the SPIR-V correspondence for the standard index type.
SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr, const SPIRVConversionOptions &options={})
bool allows(spirv::Capability capability)
Checks if the SPIR-V capability inquired is supported.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:58
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:78
Type getElementType() const
Returns the element type of this tensor type.
This class provides all of the information necessary to convert a type signature.
void addConversion(FnT &&callback)
Register a conversion function.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
U cast() const
Definition: Types.h:280
U dyn_cast_or_null() const
Definition: Types.h:275
U dyn_cast() const
Definition: Types.h:270
bool isa() const
Definition: Types.h:260
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:93
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:349
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
Type getType() const
Return the type of this value.
Definition: Value.h:114
static ArrayType get(Type elementType, unsigned elementCount)
Definition: SPIRVTypes.cpp:50
static bool isValid(VectorType)
Returns true if the given vector type is valid for the SPIR-V dialect.
Definition: SPIRVTypes.cpp:99
Type getPointeeType() const
Definition: SPIRVTypes.cpp:480
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:476
static RuntimeArrayType get(Type elementType)
Definition: SPIRVTypes.cpp:533
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, Optional< StorageClass > storage=llvm::None)
Definition: SPIRVTypes.cpp:622
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, Optional< StorageClass > storage=llvm::None)
Definition: SPIRVTypes.cpp:591
SPIR-V struct type.
Definition: SPIRVTypes.h:281
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
An attribute that specifies the target version, allowed extensions and capabilities,...
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
Definition: TargetAndABI.h:28
Version getVersion() const
bool allows(Capability) const
Returns true if the given capability is allowed.
TargetEnvAttr getAttr() const
Definition: TargetAndABI.h:61
MLIRContext * getContext() const
Returns the MLIRContext.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:230
StringRef getTypeAttrName()
Return the name of the attribute used for function types.
Value getBuiltinVariableValue(Operation *op, BuiltIn builtin, Type integerType, OpBuilder &builder)
Returns the value for the given builtin variable.
Value getPushConstantValue(Operation *op, unsigned elementCount, unsigned offset, Type integerType, OpBuilder &builder)
Gets the value at the given offset of the push constant storage with a total of elementCount integerT...
Value getElementPtr(SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Value getOpenCLElementPtr(SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Value linearizeIndex(ValueRange indices, ArrayRef< int64_t > strides, int64_t offset, Type integerType, Location loc, OpBuilder &builder)
Generates IR to perform index linearization with the given indices and their corresponding strides,...
Value getVulkanElementPtr(SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
void populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating the builtin func op to the SPIR-V diale...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
bool use64bitIndex
Use 64-bit integers to convert index types.