MLIR  14.0.0git
SPIRVConversion.cpp
Go to the documentation of this file.
1 //===- SPIRVConversion.cpp - SPIR-V Conversion Utilities ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities used to lower to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
17 #include "llvm/ADT/Sequence.h"
18 #include "llvm/ADT/StringExtras.h"
19 #include "llvm/Support/Debug.h"
20 
21 #include <functional>
22 
23 #define DEBUG_TYPE "mlir-spirv-conversion"
24 
25 using namespace mlir;
26 
27 //===----------------------------------------------------------------------===//
28 // Utility functions
29 //===----------------------------------------------------------------------===//
30 
31 /// Checks that `candidates` extension requirements are possible to be satisfied
32 /// with the given `targetEnv`.
33 ///
34 /// `candidates` is a vector of vector for extension requirements following
35 /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
36 /// convention.
37 template <typename LabelT>
39  LabelT label, const spirv::TargetEnv &targetEnv,
41  for (const auto &ors : candidates) {
42  if (targetEnv.allows(ors))
43  continue;
44 
45  LLVM_DEBUG({
46  SmallVector<StringRef> extStrings;
47  for (spirv::Extension ext : ors)
48  extStrings.push_back(spirv::stringifyExtension(ext));
49 
50  llvm::dbgs() << label << " illegal: requires at least one extension in ["
51  << llvm::join(extStrings, ", ")
52  << "] but none allowed in target environment\n";
53  });
54  return failure();
55  }
56  return success();
57 }
58 
59 /// Checks that `candidates` capability requirements are possible to be satisfied
60 /// with the given `targetEnv`.
61 ///
62 /// `candidates` is a vector of vector for capability requirements following
63 /// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
64 /// convention.
65 template <typename LabelT>
67  LabelT label, const spirv::TargetEnv &targetEnv,
69  for (const auto &ors : candidates) {
70  if (targetEnv.allows(ors))
71  continue;
72 
73  LLVM_DEBUG({
74  SmallVector<StringRef> capStrings;
75  for (spirv::Capability cap : ors)
76  capStrings.push_back(spirv::stringifyCapability(cap));
77 
78  llvm::dbgs() << label << " illegal: requires at least one capability in ["
79  << llvm::join(capStrings, ", ")
80  << "] but none allowed in target environment\n";
81  });
82  return failure();
83  }
84  return success();
85 }
86 
87 /// Returns true if the given `storageClass` needs explicit layout when used in
88 /// Shader environments; every storage class not listed below yields false.
89 static bool needsExplicitLayout(spirv::StorageClass storageClass) {
90  switch (storageClass) {
91  case spirv::StorageClass::PhysicalStorageBuffer:
92  case spirv::StorageClass::PushConstant:
93  case spirv::StorageClass::StorageBuffer:
94  case spirv::StorageClass::Uniform:
95  return true; // Shared result for the four storage classes listed above.
96  default:
97  return false;
98  }
99 }
100 
101 /// Wraps the given `elementType` in a struct and gets the pointer to the
102 /// struct in `storageClass`. This is used to satisfy Vulkan interface requirements.
103 static spirv::PointerType
104 wrapInStructAndGetPointer(Type elementType, spirv::StorageClass storageClass) {
105  auto structType = needsExplicitLayout(storageClass)
106  ? spirv::StructType::get(elementType, /*offsetInfo=*/0) // Explicit layout: offset info 0 is passed for the member.
107  : spirv::StructType::get(elementType); // Otherwise the struct is built without offset info.
108  return spirv::PointerType::get(structType, storageClass);
109 }
110 
111 //===----------------------------------------------------------------------===//
112 // Type Conversion
113 //===----------------------------------------------------------------------===//
114 
116  return IntegerType::get(getContext(), options.use64bitIndex ? 64 : 32);
117 }
118 
119 /// Mapping between SPIR-V storage classes to memref memory spaces.
120 ///
121 /// Note: memref does not have defined semantics for each memory space; it
122 /// depends on the context where it is used. There are no particular reasons
123 /// behind the number assignments; we try to follow NVVM conventions and largely
124 /// give common storage classes a smaller number. The hope is to use a symbolic
125 /// memory space representation eventually, once memref supports it.
126 // TODO: swap Generic and StorageBuffer assignment to be more akin
127 // to NVVM.
128 #define STORAGE_SPACE_MAP_LIST(MAP_FN) \
129  MAP_FN(spirv::StorageClass::Generic, 1) \
130  MAP_FN(spirv::StorageClass::StorageBuffer, 0) \
131  MAP_FN(spirv::StorageClass::Workgroup, 3) \
132  MAP_FN(spirv::StorageClass::Uniform, 4) \
133  MAP_FN(spirv::StorageClass::Private, 5) \
134  MAP_FN(spirv::StorageClass::Function, 6) \
135  MAP_FN(spirv::StorageClass::PushConstant, 7) \
136  MAP_FN(spirv::StorageClass::UniformConstant, 8) \
137  MAP_FN(spirv::StorageClass::Input, 9) \
138  MAP_FN(spirv::StorageClass::Output, 10) \
139  MAP_FN(spirv::StorageClass::CrossWorkgroup, 11) \
140  MAP_FN(spirv::StorageClass::AtomicCounter, 12) \
141  MAP_FN(spirv::StorageClass::Image, 13) \
142  MAP_FN(spirv::StorageClass::CallableDataKHR, 14) \
143  MAP_FN(spirv::StorageClass::IncomingCallableDataKHR, 15) \
144  MAP_FN(spirv::StorageClass::RayPayloadKHR, 16) \
145  MAP_FN(spirv::StorageClass::HitAttributeKHR, 17) \
146  MAP_FN(spirv::StorageClass::IncomingRayPayloadKHR, 18) \
147  MAP_FN(spirv::StorageClass::ShaderRecordBufferKHR, 19) \
148  MAP_FN(spirv::StorageClass::PhysicalStorageBuffer, 20) \
149  MAP_FN(spirv::StorageClass::CodeSectionINTEL, 21) \
150  MAP_FN(spirv::StorageClass::DeviceOnlyINTEL, 22) \
151  MAP_FN(spirv::StorageClass::HostOnlyINTEL, 23)
152 
153 unsigned
155 #define STORAGE_SPACE_MAP_FN(storage, space) \
156  case storage: \
157  return space;
158 
159  switch (storage) { STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN) }
160 #undef STORAGE_SPACE_MAP_FN
161  llvm_unreachable("unhandled storage class!");
162 }
163 
166 #define STORAGE_SPACE_MAP_FN(storage, space) \
167  case space: \
168  return storage;
169 
170  switch (space) {
172  default:
173  return llvm::None;
174  }
175 #undef STORAGE_SPACE_MAP_FN
176 }
177 
179  return options;
180 }
181 
/// Returns the MLIRContext obtained from the target environment's attribute.
182 MLIRContext *SPIRVTypeConverter::getContext() const {
183  return targetEnv.getAttr().getContext();
184 }
185 
186 #undef STORAGE_SPACE_MAP_LIST
187 
188 // TODO: This is a utility function that should probably be exposed by the
189 // SPIR-V dialect. Keeping it local till the use case arises.
190 static Optional<int64_t>
192  if (type.isa<spirv::ScalarType>()) {
193  auto bitWidth = type.getIntOrFloatBitWidth();
194  // According to the SPIR-V spec:
195  // "There is no physical size or bit pattern defined for values with boolean
196  // type. If they are stored (in conjunction with OpVariable), they can only
197  // be used with logical addressing operations, not physical, and only with
198  // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
199  // Private, Function, Input, and Output."
200  if (bitWidth == 1)
201  return llvm::None;
202  return bitWidth / 8;
203  }
204 
205  if (auto vecType = type.dyn_cast<VectorType>()) {
206  auto elementSize = getTypeNumBytes(options, vecType.getElementType());
207  if (!elementSize)
208  return llvm::None;
209  return vecType.getNumElements() * elementSize.getValue();
210  }
211 
212  if (auto memRefType = type.dyn_cast<MemRefType>()) {
213  // TODO: Layout should also be controlled by the ABI attributes. For now
214  // using the layout from MemRef.
215  int64_t offset;
216  SmallVector<int64_t, 4> strides;
217  if (!memRefType.hasStaticShape() ||
218  failed(getStridesAndOffset(memRefType, strides, offset)))
219  return llvm::None;
220 
221  // To get the size of the memref object in memory, the total size is the
222  // max(stride * dimension-size) computed for all dimensions times the size
223  // of the element.
224  auto elementSize = getTypeNumBytes(options, memRefType.getElementType());
225  if (!elementSize)
226  return llvm::None;
227 
228  if (memRefType.getRank() == 0)
229  return elementSize;
230 
231  auto dims = memRefType.getShape();
232  if (llvm::is_contained(dims, ShapedType::kDynamicSize) ||
233  offset == MemRefType::getDynamicStrideOrOffset() ||
234  llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset()))
235  return llvm::None;
236 
237  int64_t memrefSize = -1;
238  for (const auto &shape : enumerate(dims))
239  memrefSize = std::max(memrefSize, shape.value() * strides[shape.index()]);
240 
241  return (offset + memrefSize) * elementSize.getValue();
242  }
243 
244  if (auto tensorType = type.dyn_cast<TensorType>()) {
245  if (!tensorType.hasStaticShape())
246  return llvm::None;
247 
248  auto elementSize = getTypeNumBytes(options, tensorType.getElementType());
249  if (!elementSize)
250  return llvm::None;
251 
252  int64_t size = elementSize.getValue();
253  for (auto shape : tensorType.getShape())
254  size *= shape;
255 
256  return size;
257  }
258 
259  // TODO: Add size computation for other types.
260  return llvm::None;
261 }
262 
263 /// Converts a scalar `type` to a suitable type under the given `targetEnv`.
264 static Type convertScalarType(const spirv::TargetEnv &targetEnv,
265  const SPIRVTypeConverter::Options &options,
266  spirv::ScalarType type,
267  Optional<spirv::StorageClass> storageClass = {}) {
268  // Get extension and capability requirements for the given type.
271  type.getExtensions(extensions, storageClass);
272  type.getCapabilities(capabilities, storageClass);
273 
274  // If all requirements are met, then we can accept this type as-is.
275  if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
276  succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
277  return type;
278 
279  // Otherwise we need to adjust the type, which really means adjusting the
280  // bitwidth given this is a scalar type.
281 
282  if (!options.emulateNon32BitScalarTypes)
283  return nullptr;
284 
285  if (auto floatType = type.dyn_cast<FloatType>()) {
286  LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
287  return Builder(targetEnv.getContext()).getF32Type();
288  }
289 
290  auto intType = type.cast<IntegerType>();
291  LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
292  return IntegerType::get(targetEnv.getContext(), /*width=*/32,
293  intType.getSignedness());
294 }
295 
296 /// Converts a vector `type` to a suitable type under the given `targetEnv`.
297 static Type convertVectorType(const spirv::TargetEnv &targetEnv,
298  const SPIRVTypeConverter::Options &options,
299  VectorType type,
300  Optional<spirv::StorageClass> storageClass = {}) {
301  if (type.getRank() == 1 && type.getNumElements() == 1)
302  return type.getElementType();
303 
304  if (!spirv::CompositeType::isValid(type)) {
305  // TODO: Vector types with more than four elements can be translated into
306  // array types.
307  LLVM_DEBUG(llvm::dbgs() << type << " illegal: > 4-element unimplemented\n");
308  return nullptr;
309  }
310 
311  // Get extension and capability requirements for the given type.
314  type.cast<spirv::CompositeType>().getExtensions(extensions, storageClass);
315  type.cast<spirv::CompositeType>().getCapabilities(capabilities, storageClass);
316 
317  // If all requirements are met, then we can accept this type as-is.
318  if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
319  succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
320  return type;
321 
322  auto elementType = convertScalarType(
323  targetEnv, options, type.getElementType().cast<spirv::ScalarType>(),
324  storageClass);
325  if (elementType)
326  return VectorType::get(type.getShape(), elementType);
327  return nullptr;
328 }
329 
330 /// Converts a tensor `type` to a suitable type under the given `targetEnv`.
331 /// Returns a spirv::ArrayType on success and a null Type on any failure.
332 /// Note that this is mainly for lowering constant tensors. In SPIR-V one can
333 /// create composite constants with OpConstantComposite to embed relatively large
334 /// constant values and use OpCompositeExtract and OpCompositeInsert to
335 /// manipulate them, like what we do for vectors.
336 static Type convertTensorType(const spirv::TargetEnv &targetEnv,
337  const SPIRVTypeConverter::Options &options,
338  TensorType type) {
339  // TODO: Handle dynamic shapes.
340  if (!type.hasStaticShape()) {
341  LLVM_DEBUG(llvm::dbgs()
342  << type << " illegal: dynamic shape unimplemented\n");
343  return nullptr;
344  }
345 
346  auto scalarType = type.getElementType().dyn_cast<spirv::ScalarType>(); // Only scalar element types are handled.
347  if (!scalarType) {
348  LLVM_DEBUG(llvm::dbgs()
349  << type << " illegal: cannot convert non-scalar element type\n");
350  return nullptr;
351  }
352 
353  Optional<int64_t> scalarSize = getTypeNumBytes(options, scalarType); // Byte size of one original element.
354  Optional<int64_t> tensorSize = getTypeNumBytes(options, type); // Byte size of the whole original tensor.
355  if (!scalarSize || !tensorSize) {
356  LLVM_DEBUG(llvm::dbgs()
357  << type << " illegal: cannot deduce element count\n");
358  return nullptr;
359  }
360 
361  auto arrayElemCount = *tensorSize / *scalarSize; // Element count is computed from the original (pre-conversion) sizes.
362  auto arrayElemType = convertScalarType(targetEnv, options, scalarType); // No storage class is passed here.
363  if (!arrayElemType)
364  return nullptr;
365  Optional<int64_t> arrayElemSize = getTypeNumBytes(options, arrayElemType); // Byte size of the converted element type.
366  if (!arrayElemSize) {
367  LLVM_DEBUG(llvm::dbgs()
368  << type << " illegal: cannot deduce converted element size\n");
369  return nullptr;
370  }
371 
372  return spirv::ArrayType::get(arrayElemType, arrayElemCount, *arrayElemSize); // Third argument (the array stride) is the converted element size.
373 }
374 
376  const SPIRVTypeConverter::Options &options,
377  MemRefType type) {
378  Optional<spirv::StorageClass> storageClass =
380  type.getMemorySpaceAsInt());
381  if (!storageClass) {
382  LLVM_DEBUG(llvm::dbgs()
383  << type << " illegal: cannot convert memory space\n");
384  return nullptr;
385  }
386 
387  unsigned numBoolBits = options.boolNumBits;
388  if (numBoolBits != 8) {
389  LLVM_DEBUG(llvm::dbgs()
390  << "using non-8-bit storage for bool types unimplemented");
391  return nullptr;
392  }
393  auto elementType = IntegerType::get(type.getContext(), numBoolBits)
394  .dyn_cast<spirv::ScalarType>();
395  if (!elementType)
396  return nullptr;
397  Type arrayElemType =
398  convertScalarType(targetEnv, options, elementType, storageClass);
399  if (!arrayElemType)
400  return nullptr;
401  Optional<int64_t> arrayElemSize = getTypeNumBytes(options, arrayElemType);
402  if (!arrayElemSize) {
403  LLVM_DEBUG(llvm::dbgs()
404  << type << " illegal: cannot deduce converted element size\n");
405  return nullptr;
406  }
407 
408  if (!type.hasStaticShape()) {
409  auto arrayType =
410  spirv::RuntimeArrayType::get(arrayElemType, *arrayElemSize);
411  return wrapInStructAndGetPointer(arrayType, *storageClass);
412  }
413 
414  int64_t memrefSize = (type.getNumElements() * numBoolBits + 7) / 8;
415  auto arrayElemCount = (memrefSize + *arrayElemSize - 1) / *arrayElemSize;
416  auto arrayType =
417  spirv::ArrayType::get(arrayElemType, arrayElemCount, *arrayElemSize);
418 
419  return wrapInStructAndGetPointer(arrayType, *storageClass);
420 }
421 
422 static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
423  const SPIRVTypeConverter::Options &options,
424  MemRefType type) {
425  if (type.getElementType().isa<IntegerType>() &&
426  type.getElementTypeBitWidth() == 1) {
427  return convertBoolMemrefType(targetEnv, options, type);
428  }
429 
430  Optional<spirv::StorageClass> storageClass =
432  type.getMemorySpaceAsInt());
433  if (!storageClass) {
434  LLVM_DEBUG(llvm::dbgs()
435  << type << " illegal: cannot convert memory space\n");
436  return nullptr;
437  }
438 
439  Type arrayElemType;
440  Type elementType = type.getElementType();
441  if (auto vecType = elementType.dyn_cast<VectorType>()) {
442  arrayElemType =
443  convertVectorType(targetEnv, options, vecType, storageClass);
444  } else if (auto scalarType = elementType.dyn_cast<spirv::ScalarType>()) {
445  arrayElemType =
446  convertScalarType(targetEnv, options, scalarType, storageClass);
447  } else {
448  LLVM_DEBUG(
449  llvm::dbgs()
450  << type
451  << " unhandled: can only convert scalar or vector element type\n");
452  return nullptr;
453  }
454  if (!arrayElemType)
455  return nullptr;
456 
457  Optional<int64_t> elementSize = getTypeNumBytes(options, elementType);
458  if (!elementSize) {
459  LLVM_DEBUG(llvm::dbgs()
460  << type << " illegal: cannot deduce element size\n");
461  return nullptr;
462  }
463 
464  Optional<int64_t> arrayElemSize = getTypeNumBytes(options, arrayElemType);
465  if (!arrayElemSize) {
466  LLVM_DEBUG(llvm::dbgs()
467  << type << " illegal: cannot deduce converted element size\n");
468  return nullptr;
469  }
470 
471  if (!type.hasStaticShape()) {
472  auto arrayType =
473  spirv::RuntimeArrayType::get(arrayElemType, *arrayElemSize);
474  return wrapInStructAndGetPointer(arrayType, *storageClass);
475  }
476 
477  Optional<int64_t> memrefSize = getTypeNumBytes(options, type);
478  if (!memrefSize) {
479  LLVM_DEBUG(llvm::dbgs()
480  << type << " illegal: cannot deduce element count\n");
481  return nullptr;
482  }
483 
484  auto arrayElemCount = *memrefSize / *elementSize;
485 
486 
487  auto arrayType =
488  spirv::ArrayType::get(arrayElemType, arrayElemCount, *arrayElemSize);
489 
490  return wrapInStructAndGetPointer(arrayType, *storageClass);
491 }
492 
494  Options options)
495  : targetEnv(targetAttr), options(options) {
496  // Add conversions. The order matters here: later ones will be tried earlier.
497 
498  // Allow all SPIR-V dialect specific types. This assumes all builtin types
499  // adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType)
500  // were tried before.
501  //
502  // TODO: this assumes that the SPIR-V types are valid to use in
503  // the given target environment, which should be the case if the whole
504 // pipeline is driven by the same target environment. Still, we probably
505 // want to validate and convert to be safe.
506  addConversion([](spirv::SPIRVType type) { return type; });
507 
508  addConversion([this](IndexType /*indexType*/) { return getIndexType(); });
509 
510  addConversion([this](IntegerType intType) -> Optional<Type> {
511  if (auto scalarType = intType.dyn_cast<spirv::ScalarType>())
512  return convertScalarType(this->targetEnv, this->options, scalarType);
513  return Type();
514  });
515 
516  addConversion([this](FloatType floatType) -> Optional<Type> {
517  if (auto scalarType = floatType.dyn_cast<spirv::ScalarType>())
518  return convertScalarType(this->targetEnv, this->options, scalarType);
519  return Type();
520  });
521 
522  addConversion([this](VectorType vectorType) {
523  return convertVectorType(this->targetEnv, this->options, vectorType);
524  });
525 
526  addConversion([this](TensorType tensorType) {
527  return convertTensorType(this->targetEnv, this->options, tensorType);
528  });
529 
530  addConversion([this](MemRefType memRefType) {
531  return convertMemrefType(this->targetEnv, this->options, memRefType);
532  });
533 }
534 
535 //===----------------------------------------------------------------------===//
536 // FuncOp Conversion Patterns
537 //===----------------------------------------------------------------------===//
538 
539 namespace {
540 /// A pattern for rewriting function signature to convert arguments of functions
541 /// to be of valid SPIR-V types.
542 class FuncOpConversion final : public OpConversionPattern<FuncOp> {
543 public:
545 
547  matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor,
548  ConversionPatternRewriter &rewriter) const override;
549 };
550 } // namespace
551 
553 FuncOpConversion::matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor,
554  ConversionPatternRewriter &rewriter) const {
555  auto fnType = funcOp.getType();
556  if (fnType.getNumResults() > 1)
557  return failure();
558 
559  TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
560  for (const auto &argType : enumerate(fnType.getInputs())) {
561  auto convertedType = getTypeConverter()->convertType(argType.value());
562  if (!convertedType)
563  return failure();
564  signatureConverter.addInputs(argType.index(), convertedType);
565  }
566 
567  Type resultType;
568  if (fnType.getNumResults() == 1) {
569  resultType = getTypeConverter()->convertType(fnType.getResult(0));
570  if (!resultType)
571  return failure();
572  }
573 
574  // Create the converted spv.func op.
575  auto newFuncOp = rewriter.create<spirv::FuncOp>(
576  funcOp.getLoc(), funcOp.getName(),
577  rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
578  resultType ? TypeRange(resultType)
579  : TypeRange()));
580 
581  // Copy over all attributes other than the function name and type.
582  for (const auto &namedAttr : funcOp->getAttrs()) {
583  if (namedAttr.getName() != FunctionOpInterface::getTypeAttrName() &&
584  namedAttr.getName() != SymbolTable::getSymbolAttrName())
585  newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
586  }
587 
588  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
589  newFuncOp.end());
590  if (failed(rewriter.convertRegionTypes(
591  &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter)))
592  return failure();
593  rewriter.eraseOp(funcOp);
594  return success();
595 }
596 
598  RewritePatternSet &patterns) {
599  patterns.add<FuncOpConversion>(typeConverter, patterns.getContext());
600 }
601 
602 //===----------------------------------------------------------------------===//
603 // Builtin Variables
604 //===----------------------------------------------------------------------===//
605 
/// Returns the spv.GlobalVariable in `body` whose BuiltIn decoration attribute
/// symbolizes to `builtin`, or a null op if no such variable exists.
606 static spirv::GlobalVariableOp getBuiltinVariable(Block &body,
607  spirv::BuiltIn builtin) {
608  // Look through all global variables in the given `body` block and check if
609  // there is a spv.GlobalVariable that has the same `builtin` attribute.
610  for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
611  if (auto builtinAttr = varOp->getAttrOfType<StringAttr>(
612  spirv::SPIRVDialect::getAttributeName(
613  spirv::Decoration::BuiltIn))) {
614  auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue()); // Parse the attribute string back into the enum.
615  if (varBuiltIn && varBuiltIn.getValue() == builtin) {
616  return varOp;
617  }
618  }
619  }
620  return nullptr; // No variable in `body` carries the requested builtin.
621 }
622 
623 /// Gets the name of the global variable for a builtin: the stringified builtin wrapped as "__builtin_var_<builtin>__".
624 static std::string getBuiltinVarName(spirv::BuiltIn builtin) {
625  return std::string("__builtin_var_") + stringifyBuiltIn(builtin).str() + "__";
626 }
627 
628 /// Gets or inserts a global variable for a builtin within `body` block.
/// An existing variable found by getBuiltinVariable() is reused; otherwise a new
/// one is created at the start of `body`, typed as a pointer in the Input
/// storage class to either a 3-element vector of `integerType` or a plain
/// `integerType`, depending on `builtin`.
/// NOTE(review): for a builtin outside the two groups below an error is emitted
/// and a null op is returned; callers need to handle that case.
629 static spirv::GlobalVariableOp
630 getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
631  Type integerType, OpBuilder &builder) {
632  if (auto varOp = getBuiltinVariable(body, builtin))
633  return varOp;
634 
635  OpBuilder::InsertionGuard guard(builder); // Expected to restore the caller's insertion point on scope exit.
636  builder.setInsertionPointToStart(&body);
637 
638  spirv::GlobalVariableOp newVarOp;
639  switch (builtin) {
640  case spirv::BuiltIn::NumWorkgroups:
641  case spirv::BuiltIn::WorkgroupSize:
642  case spirv::BuiltIn::WorkgroupId:
643  case spirv::BuiltIn::LocalInvocationId:
644  case spirv::BuiltIn::GlobalInvocationId: { // These builtins get a 3-element vector of `integerType`.
645  auto ptrType = spirv::PointerType::get(VectorType::get({3}, integerType),
646  spirv::StorageClass::Input);
647  std::string name = getBuiltinVarName(builtin);
648  newVarOp =
649  builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
650  break;
651  }
652  case spirv::BuiltIn::SubgroupId:
653  case spirv::BuiltIn::NumSubgroups:
654  case spirv::BuiltIn::SubgroupSize: { // These builtins get a single `integerType` value.
655  auto ptrType =
656  spirv::PointerType::get(integerType, spirv::StorageClass::Input);
657  std::string name = getBuiltinVarName(builtin);
658  newVarOp =
659  builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
660  break;
661  }
662  default:
663  emitError(loc, "unimplemented builtin variable generation for ")
664  << stringifyBuiltIn(builtin);
665  }
666  return newVarOp; // Still null if the default case above was taken.
667 }
668 
670  spirv::BuiltIn builtin,
671  Type integerType,
672  OpBuilder &builder) {
674  if (!parent) {
675  op->emitError("expected operation to be within a module-like op");
676  return nullptr;
677  }
678 
679  spirv::GlobalVariableOp varOp =
680  getOrInsertBuiltinVariable(*parent->getRegion(0).begin(), op->getLoc(),
681  builtin, integerType, builder);
682  Value ptr = builder.create<spirv::AddressOfOp>(op->getLoc(), varOp);
683  return builder.create<spirv::LoadOp>(op->getLoc(), ptr);
684 }
685 
686 //===----------------------------------------------------------------------===//
687 // Push constant storage
688 //===----------------------------------------------------------------------===//
689 
690 /// Returns the pointer type for the push constant storage containing
691 /// `elementCount` values of `indexType`: a PushConstant-class pointer to a struct holding one array.
/// NOTE(review): the array stride is fixed at 4 bytes regardless of the width
/// of `indexType` — confirm this is intended when `indexType` is not 32-bit.
/// `builder` is not used by this function.
692 static spirv::PointerType getPushConstantStorageType(unsigned elementCount,
693  Builder &builder,
694  Type indexType) {
695  auto arrayType = spirv::ArrayType::get(indexType, elementCount,
696  /*stride=*/4);
697  auto structType = spirv::StructType::get({arrayType}, /*offsetInfo=*/0); // Single array member with offset info 0.
698  return spirv::PointerType::get(structType, spirv::StorageClass::PushConstant);
699 }
700 
701 /// Returns the push constant variable containing `elementCount` 32-bit integer
702 /// values in `body`. Returns null op if such an op does not exist.
703 static spirv::GlobalVariableOp getPushConstantVariable(Block &body,
704  unsigned elementCount) {
705  for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
706  auto ptrType = varOp.type().dyn_cast<spirv::PointerType>();
707  if (!ptrType)
708  continue;
709 
710  // Note that Vulkan requires "There must be no more than one push constant
711  // block statically used per shader entry point." So we should always reuse
712  // the existing one.
713  if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) {
714  auto numElements = ptrType.getPointeeType()
716  .getElementType(0)
718  .getNumElements();
719  if (numElements == elementCount)
720  return varOp;
721  }
722  }
723  return nullptr;
724 }
725 
726 /// Gets or inserts a global variable for push constant storage containing
727 /// `elementCount` 32-bit integer values in `block`.
728 static spirv::GlobalVariableOp
730  unsigned elementCount, OpBuilder &b,
731  Type indexType) {
732  if (auto varOp = getPushConstantVariable(block, elementCount))
733  return varOp;
734 
735  auto builder = OpBuilder::atBlockBegin(&block, b.getListener());
736  auto type = getPushConstantStorageType(elementCount, builder, indexType);
737  const char *name = "__push_constant_var__";
738  return builder.create<spirv::GlobalVariableOp>(loc, type, name,
739  /*initializer=*/nullptr);
740 }
741 
742 Value spirv::getPushConstantValue(Operation *op, unsigned elementCount,
743  unsigned offset, Type integerType,
744  OpBuilder &builder) {
745  Location loc = op->getLoc();
747  if (!parent) {
748  op->emitError("expected operation to be within a module-like op");
749  return nullptr;
750  }
751 
752  spirv::GlobalVariableOp varOp = getOrInsertPushConstantVariable(
753  loc, parent->getRegion(0).front(), elementCount, builder, integerType);
754 
755  Value zeroOp = spirv::ConstantOp::getZero(integerType, loc, builder);
756  Value offsetOp = builder.create<spirv::ConstantOp>(
757  loc, integerType, builder.getI32IntegerAttr(offset));
758  auto addrOp = builder.create<spirv::AddressOfOp>(loc, varOp);
759  auto acOp = builder.create<spirv::AccessChainOp>(
760  loc, addrOp, llvm::makeArrayRef({zeroOp, offsetOp}));
761  return builder.create<spirv::LoadOp>(loc, acOp);
762 }
763 
764 //===----------------------------------------------------------------------===//
765 // Index calculation
766 //===----------------------------------------------------------------------===//
767 
769  int64_t offset, Type integerType,
770  Location loc, OpBuilder &builder) {
771  assert(indices.size() == strides.size() &&
772  "must provide indices for all dimensions");
773 
774  // TODO: Consider moving to use affine.apply and patterns converting
775  // affine.apply to standard ops. This needs converting to SPIR-V passes to be
776  // broken down into progressive small steps so we can have intermediate steps
777  // using other dialects. At the moment SPIR-V is the final sink.
778 
779  Value linearizedIndex = builder.create<spirv::ConstantOp>(
780  loc, integerType, IntegerAttr::get(integerType, offset));
781  for (const auto &index : llvm::enumerate(indices)) {
782  Value strideVal = builder.create<spirv::ConstantOp>(
783  loc, integerType,
784  IntegerAttr::get(integerType, strides[index.index()]));
785  Value update = builder.create<spirv::IMulOp>(loc, strideVal, index.value());
786  linearizedIndex =
787  builder.create<spirv::IAddOp>(loc, linearizedIndex, update);
788  }
789  return linearizedIndex;
790 }
791 
/// Creates a spv.AccessChain on `basePtr` that addresses the element of
/// `baseType` selected by `indices`. The chain always has two indices: a
/// leading zero to step into the wrapping struct, then either zero (rank-0
/// memref) or the linearized element index. Returns a null op when the strides
/// or offset of `baseType` cannot be obtained or are dynamic.
792 spirv::AccessChainOp mlir::spirv::getElementPtr(
793  SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr,
794  ValueRange indices, Location loc, OpBuilder &builder) {
795  // Get strides and offset of the MemRefType and verify they are static.
796 
797  int64_t offset;
798  SmallVector<int64_t, 4> strides;
799  if (failed(getStridesAndOffset(baseType, strides, offset)) ||
800  llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset()) ||
801  offset == MemRefType::getDynamicStrideOrOffset()) {
802  return nullptr;
803  }
804 
805  auto indexType = typeConverter.getIndexType(); // Integer type used for all generated index constants.
806 
807  SmallVector<Value, 2> linearizedIndices;
808  auto zero = spirv::ConstantOp::getZero(indexType, loc, builder);
809 
810  // Add a '0' at the start to index into the struct.
811  linearizedIndices.push_back(zero);
812 
813  if (baseType.getRank() == 0) {
814  linearizedIndices.push_back(zero); // Rank-0 memref: `indices` is not consulted; element 0 is addressed.
815  } else {
816  linearizedIndices.push_back(
817  linearizeIndex(indices, strides, offset, indexType, loc, builder));
818  }
819  return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
820 }
821 
822 //===----------------------------------------------------------------------===//
823 // SPIR-V ConversionTarget
824 //===----------------------------------------------------------------------===//
825 
826 std::unique_ptr<SPIRVConversionTarget>
828  std::unique_ptr<SPIRVConversionTarget> target(
829  // std::make_unique does not work here because the constructor is private.
830  new SPIRVConversionTarget(targetAttr));
831  SPIRVConversionTarget *targetPtr = target.get();
832  target->addDynamicallyLegalDialect<spirv::SPIRVDialect>(
833  // We need to capture the raw pointer here because it is stable:
834  // target will be destroyed once this function is returned.
835  [targetPtr](Operation *op) { return targetPtr->isLegalOp(op); });
836  return target;
837 }
838 
/// Initializes the ConversionTarget base with the context of `targetAttr` and
/// the `targetEnv` member from `targetAttr` itself.
839 SPIRVConversionTarget::SPIRVConversionTarget(spirv::TargetEnvAttr targetAttr)
840  : ConversionTarget(*targetAttr.getContext()), targetEnv(targetAttr) {}
841 
/// Returns true if `op` is legal under this target's environment. The checks,
/// in order: the op's min/max version bounds, its required extensions, its
/// required capabilities, that every operand/result type is a SPIR-V type, and
/// that the extensions and capabilities required by those types (plus the
/// declared type of a spv.GlobalVariable) are allowed. Any failed check
/// returns false.
842 bool SPIRVConversionTarget::isLegalOp(Operation *op) {
843  // Make sure this op is available at the given version. Ops not implementing
844  // QueryMinVersionInterface/QueryMaxVersionInterface are available to all
845  // SPIR-V versions.
846  if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
847  Optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
848  if (minVersion && *minVersion > this->targetEnv.getVersion()) {
849  LLVM_DEBUG(llvm::dbgs()
850  << op->getName() << " illegal: requiring min version "
851  << spirv::stringifyVersion(*minVersion) << "\n");
852  return false;
853  }
854  }
855  if (auto maxVersionIfx = dyn_cast<spirv::QueryMaxVersionInterface>(op)) {
856  Optional<spirv::Version> maxVersion = maxVersionIfx.getMaxVersion();
857  if (maxVersion && *maxVersion < this->targetEnv.getVersion()) {
858  LLVM_DEBUG(llvm::dbgs()
859  << op->getName() << " illegal: requiring max version "
860  << spirv::stringifyVersion(*maxVersion) << "\n");
861  return false;
862  }
863  }
864 
865  // Make sure this op's required extensions are allowed to use. Ops not
866  // implementing QueryExtensionInterface do not require extensions to be
867  // available.
868  if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
869  if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
870  extensions.getExtensions())))
871  return false;
872 
873  // Make sure this op's required capabilities are allowed to use. Ops not
874  // implementing QueryCapabilityInterface do not require capabilities to be
875  // available.
876  if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
877  if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
878  capabilities.getCapabilities())))
879  return false;
880 
881  SmallVector<Type, 4> valueTypes; // Collects every operand type followed by every result type.
882  valueTypes.append(op->operand_type_begin(), op->operand_type_end());
883  valueTypes.append(op->result_type_begin(), op->result_type_end());
884 
885  // Ensure that all types have been converted to SPIRV types.
886  if (llvm::any_of(valueTypes,
887  [](Type t) { return !t.isa<spirv::SPIRVType>(); }))
888  return false;
889 
890  // Special treatment for global variables, whose type requirements are
891  // conveyed by type attributes. The type is appended after the SPIR-V type
892  // check above, so it only takes part in the requirement checks below.
893  if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
894  valueTypes.push_back(globalVar.type());
895 
896  // Make sure the op's operands/results use types that are allowed by the
897  // target environment.
898  SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
899  SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
900  for (Type valueType : valueTypes) {
901  typeExtensions.clear(); // The scratch vectors are reused across iterations.
902  valueType.cast<spirv::SPIRVType>().getExtensions(typeExtensions);
903  if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
904  typeExtensions)))
905  return false;
906 
907  typeCapabilities.clear();
908  valueType.cast<spirv::SPIRVType>().getCapabilities(typeCapabilities);
909  if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
910  typeCapabilities)))
911  return false;
912  }
913 
914  return true;
915 }
Include the generated interface declarations.
OpTy create(Location location, Args &&...args)
Create an operation of specific op type at the current insertion point.
Definition: Builders.h:430
Type getIndexType() const
Gets the SPIR-V correspondence for the standard index type.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:55
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
const Options & getOptions() const
Returns the options controlling the SPIR-V type converter.
Type getPointeeType() const
Definition: SPIRVTypes.cpp:395
Block represents an ordered list of Operations.
Definition: Block.h:29
bool use64bitIndex
Use 64-bit integers to convert index types.
Block & front()
Definition: Region.h:65
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, Optional< StorageClass > storage=llvm::None)
Definition: SPIRVTypes.cpp:506
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity...
Definition: SPIRVOps.cpp:639
static Type convertMemrefType(const spirv::TargetEnv &targetEnv, const SPIRVTypeConverter::Options &options, MemRefType type)
static unsigned getMemorySpaceForStorageClass(spirv::StorageClass)
Returns the corresponding memory space for memref given a SPIR-V storage class.
TargetEnvAttr getAttr() const
Definition: TargetAndABI.h:61
static Type convertTensorType(const spirv::TargetEnv &targetEnv, const SPIRVTypeConverter::Options &options, TensorType type)
Converts a tensor type to a suitable type under the given targetEnv.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value...
Definition: LogicalResult.h:68
static RuntimeArrayType get(Type elementType)
Definition: SPIRVTypes.cpp:448
SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr, Options options={})
StringRef getTypeAttrName()
Return the name of the attribute used for function types.
static OpBuilder atBlockBegin(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still ins...
Definition: Builders.h:209
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
static spirv::GlobalVariableOp getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin, Type integerType, OpBuilder &builder)
Gets or inserts a global variable for a builtin within body block.
bool allows(Capability) const
Returns true if the given capability is allowed.
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
Definition: SPIRVTypes.cpp:948
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
static Type convertVectorType(const spirv::TargetEnv &targetEnv, const SPIRVTypeConverter::Options &options, VectorType type, Optional< spirv::StorageClass > storageClass={})
Converts a vector type to a suitable type under the given targetEnv.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
This class provides all of the information necessary to convert a type signature. ...
iterator begin()
Definition: Region.h:55
#define STORAGE_SPACE_MAP_LIST(MAP_FN)
Mapping between SPIR-V storage classes to memref memory spaces.
Value getBuiltinVariableValue(Operation *op, BuiltIn builtin, Type integerType, OpBuilder &builder)
Returns the value for the given builtin variable.
static ArrayType get(Type elementType, unsigned elementCount)
Definition: SPIRVTypes.cpp:48
static spirv::GlobalVariableOp getPushConstantVariable(Block &body, unsigned elementCount)
Returns the push constant variable containing elementCount 32-bit integer values in body...
MLIRContext * getContext() const
Returns the MLIRContext.
U dyn_cast() const
Definition: Types.h:244
void populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating the builtin func op to the SPIR-V diale...
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:117
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:206
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:38
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:106
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener...
Definition: Builders.h:251
result_type_iterator result_type_end()
Definition: Operation.h:296
operand_type_iterator operand_type_begin()
Definition: Operation.h:264
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:73
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before) override
PatternRewriter hook for moving blocks out of a region.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
Definition: TargetAndABI.h:28
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:391
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:84
operand_type_iterator operand_type_end()
Definition: Operation.h:265
static std::string getBuiltinVarName(spirv::BuiltIn builtin)
Gets name of global variable for a builtin.
static spirv::GlobalVariableOp getOrInsertPushConstantVariable(Location loc, Block &block, unsigned elementCount, OpBuilder &b, Type indexType)
Gets or inserts a global variable for push constant storage containing elementCount 32-bit integer va...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:362
static Optional< int64_t > getTypeNumBytes(const SPIRVTypeConverter::Options &options, Type type)
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:279
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
void addConversion(FnT &&callback)
Register a conversion function.
Value getPushConstantValue(Operation *op, unsigned elementCount, unsigned offset, Type integerType, OpBuilder &builder)
Gets the value at the given offset of the push constant storage with a total of elementCount integerT...
This class is a general helper class for creating context-global objects like types, attributes, and affine expressions.
Definition: Builders.h:49
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&... args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
Definition: PatternMatch.h:930
result_type_iterator result_type_begin()
Definition: Operation.h:295
static spirv::PointerType getPushConstantStorageType(unsigned elementCount, Builder &builder, Type indexType)
Returns the pointer type for the push constant storage containing elementCount 32-bit integer values...
static spirv::GlobalVariableOp getBuiltinVariable(Block &body, spirv::BuiltIn builtin)
static int64_t getNumElements(ShapedType type)
Definition: TensorOps.cpp:678
Value linearizeIndex(ValueRange indices, ArrayRef< int64_t > strides, int64_t offset, Type integerType, Location loc, OpBuilder &builder)
Generates IR to perform index linearization with the given indices and their corresponding strides...
spirv::AccessChainOp getElementPtr(SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr...
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, Optional< StorageClass > storage=llvm::None)
Definition: SPIRVTypes.cpp:537
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'. ...
Definition: Block.h:184
static VectorType vectorType(CodeGen &codegen, Type etp)
Constructs vector type.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
Type getElementType() const
Returns the element type of this tensor type.
SPIR-V struct type.
Definition: SPIRVTypes.h:278
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:91
This class implements a pattern rewriter for use with ConversionPatterns.
static spirv::PointerType wrapInStructAndGetPointer(Type elementType, spirv::StorageClass storageClass)
Wraps the given elementType in a struct and gets the pointer to the struct.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
#define STORAGE_SPACE_MAP_FN(storage, space)
unsigned boolNumBits
The number of bits to store a boolean value.
This class describes a specific conversion target.
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:67
static Optional< spirv::StorageClass > getStorageClassForMemorySpace(unsigned space)
Returns the SPIR-V storage class given a memory space for memref.
static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv, const SPIRVTypeConverter::Options &options, MemRefType type)
static Type convertScalarType(const spirv::TargetEnv &targetEnv, const SPIRVTypeConverter::Options &options, spirv::ScalarType type, Optional< spirv::StorageClass > storageClass={})
Converts a scalar type to a suitable type under the given targetEnv.
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:57
bool isa() const
Definition: Types.h:234
Version getVersion() const
static bool needsExplicitLayout(spirv::StorageClass storageClass)
Returns true if the given storageClass needs explicit layout when used in Shader environments.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:231
This class helps build Operations.
Definition: Builders.h:177
This class provides an abstraction over the different types of ranges over Values.
static LogicalResult checkCapabilityRequirements(LabelT label, const spirv::TargetEnv &targetEnv, const spirv::SPIRVType::CapabilityArrayRefVector &candidates)
Checks that candidates capability requirements are possible to be satisfied with the given isAllowedFn...
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:429
MLIRContext * getContext() const
Definition: PatternMatch.h:906
bool emulateNon32BitScalarTypes
Whether to emulate non-32-bit scalar types with 32-bit scalar types if no native support.
An attribute that specifies the target version, allowed extensions and capabilities, and resource limits.
Type conversion from builtin types to SPIR-V types for shader interface.
FailureOr< Block * > convertRegionTypes(Region *region, TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Convert the types of block arguments within the given region.
static LogicalResult checkExtensionRequirements(LabelT label, const spirv::TargetEnv &targetEnv, const spirv::SPIRVType::ExtensionArrayRefVector &candidates)
Checks that candidates extension requirements are possible to be satisfied with the given targetEnv...
U cast() const
Definition: Types.h:250
static Value max(ImplicitLocOpBuilder &builder, Value a, Value b)
static bool isValid(VectorType)
Returns true if the given vector type is valid for the SPIR-V dialect.
Definition: SPIRVTypes.cpp:97