MLIR  15.0.0git
SPIRVConversion.cpp
Go to the documentation of this file.
1 //===- SPIRVConversion.cpp - SPIR-V Conversion Utilities ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities used to lower to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
18 #include "llvm/ADT/Sequence.h"
19 #include "llvm/ADT/StringExtras.h"
20 #include "llvm/Support/Debug.h"
21 
22 #include <functional>
23 
24 #define DEBUG_TYPE "mlir-spirv-conversion"
25 
26 using namespace mlir;
27 
28 //===----------------------------------------------------------------------===//
29 // Utility functions
30 //===----------------------------------------------------------------------===//
31 
32 /// Checks that the `candidates` extension requirements can be satisfied by
33 /// the given `targetEnv`.
34 ///
35 /// `candidates` is a vector of vector for extension requirements following
36 /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
37 /// convention.
38 template <typename LabelT>
40  LabelT label, const spirv::TargetEnv &targetEnv,
42  for (const auto &ors : candidates) {
43  if (targetEnv.allows(ors))
44  continue;
45 
46  LLVM_DEBUG({
47  SmallVector<StringRef> extStrings;
48  for (spirv::Extension ext : ors)
49  extStrings.push_back(spirv::stringifyExtension(ext));
50 
51  llvm::dbgs() << label << " illegal: requires at least one extension in ["
52  << llvm::join(extStrings, ", ")
53  << "] but none allowed in target environment\n";
54  });
55  return failure();
56  }
57  return success();
58 }
59 
60 /// Checks that the `candidates` capability requirements can be satisfied by
61 /// the given `targetEnv`.
62 ///
63 /// `candidates` is a vector of vector for capability requirements following
64 /// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
65 /// convention.
66 template <typename LabelT>
68  LabelT label, const spirv::TargetEnv &targetEnv,
70  for (const auto &ors : candidates) {
71  if (targetEnv.allows(ors))
72  continue;
73 
74  LLVM_DEBUG({
75  SmallVector<StringRef> capStrings;
76  for (spirv::Capability cap : ors)
77  capStrings.push_back(spirv::stringifyCapability(cap));
78 
79  llvm::dbgs() << label << " illegal: requires at least one capability in ["
80  << llvm::join(capStrings, ", ")
81  << "] but none allowed in target environment\n";
82  });
83  return failure();
84  }
85  return success();
86 }
87 
88 /// Returns true if the given `storageClass` needs explicit layout when used in
89 /// Shader environments.
90 static bool needsExplicitLayout(spirv::StorageClass storageClass) {
91  switch (storageClass) {
92  case spirv::StorageClass::PhysicalStorageBuffer:
93  case spirv::StorageClass::PushConstant:
94  case spirv::StorageClass::StorageBuffer:
95  case spirv::StorageClass::Uniform:
96  return true;
97  default:
98  return false;
99  }
100 }
101 
102 /// Wraps the given `elementType` in a struct and gets the pointer to the
103 /// struct. This is used to satisfy Vulkan interface requirements.
104 static spirv::PointerType
105 wrapInStructAndGetPointer(Type elementType, spirv::StorageClass storageClass) {
106  auto structType = needsExplicitLayout(storageClass)
107  ? spirv::StructType::get(elementType, /*offsetInfo=*/0)
108  : spirv::StructType::get(elementType);
109  return spirv::PointerType::get(structType, storageClass);
110 }
111 
112 //===----------------------------------------------------------------------===//
113 // Type Conversion
114 //===----------------------------------------------------------------------===//
115 
117  return IntegerType::get(getContext(), options.use64bitIndex ? 64 : 32);
118 }
119 
120 /// Mapping from SPIR-V storage classes to memref memory spaces.
121 ///
122 /// Note: memref does not have a defined semantics for each memory space; it
123 /// depends on the context where it is used. There are no particular reasons
124 /// behind the number assignments; we try to follow NVVM conventions and largely
125 /// give common storage classes a smaller number. The hope is to use symbolic
126 /// memory space representation eventually after memref supports it.
127 // TODO: swap Generic and StorageBuffer assignment to be more akin
128 // to NVVM.
129 #define STORAGE_SPACE_MAP_LIST(MAP_FN) \
130  MAP_FN(spirv::StorageClass::Generic, 1) \
131  MAP_FN(spirv::StorageClass::StorageBuffer, 0) \
132  MAP_FN(spirv::StorageClass::Workgroup, 3) \
133  MAP_FN(spirv::StorageClass::Uniform, 4) \
134  MAP_FN(spirv::StorageClass::Private, 5) \
135  MAP_FN(spirv::StorageClass::Function, 6) \
136  MAP_FN(spirv::StorageClass::PushConstant, 7) \
137  MAP_FN(spirv::StorageClass::UniformConstant, 8) \
138  MAP_FN(spirv::StorageClass::Input, 9) \
139  MAP_FN(spirv::StorageClass::Output, 10) \
140  MAP_FN(spirv::StorageClass::CrossWorkgroup, 11) \
141  MAP_FN(spirv::StorageClass::AtomicCounter, 12) \
142  MAP_FN(spirv::StorageClass::Image, 13) \
143  MAP_FN(spirv::StorageClass::CallableDataKHR, 14) \
144  MAP_FN(spirv::StorageClass::IncomingCallableDataKHR, 15) \
145  MAP_FN(spirv::StorageClass::RayPayloadKHR, 16) \
146  MAP_FN(spirv::StorageClass::HitAttributeKHR, 17) \
147  MAP_FN(spirv::StorageClass::IncomingRayPayloadKHR, 18) \
148  MAP_FN(spirv::StorageClass::ShaderRecordBufferKHR, 19) \
149  MAP_FN(spirv::StorageClass::PhysicalStorageBuffer, 20) \
150  MAP_FN(spirv::StorageClass::CodeSectionINTEL, 21) \
151  MAP_FN(spirv::StorageClass::DeviceOnlyINTEL, 22) \
152  MAP_FN(spirv::StorageClass::HostOnlyINTEL, 23)
153 
154 unsigned
156 #define STORAGE_SPACE_MAP_FN(storage, space) \
157  case storage: \
158  return space;
159 
160  switch (storage) { STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN) }
161 #undef STORAGE_SPACE_MAP_FN
162  llvm_unreachable("unhandled storage class!");
163 }
164 
167 #define STORAGE_SPACE_MAP_FN(storage, space) \
168  case space: \
169  return storage;
170 
171  switch (space) {
173  default:
174  return llvm::None;
175  }
176 #undef STORAGE_SPACE_MAP_FN
177 }
178 
180  return options;
181 }
182 
183 MLIRContext *SPIRVTypeConverter::getContext() const {
184  return targetEnv.getAttr().getContext();
185 }
186 
187 #undef STORAGE_SPACE_MAP_LIST
188 
189 // TODO: This is a utility function that should probably be exposed by the
190 // SPIR-V dialect. Keeping it local till the use case arises.
191 static Optional<int64_t>
193  if (type.isa<spirv::ScalarType>()) {
194  auto bitWidth = type.getIntOrFloatBitWidth();
195  // According to the SPIR-V spec:
196  // "There is no physical size or bit pattern defined for values with boolean
197  // type. If they are stored (in conjunction with OpVariable), they can only
198  // be used with logical addressing operations, not physical, and only with
199  // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
200  // Private, Function, Input, and Output."
201  if (bitWidth == 1)
202  return llvm::None;
203  return bitWidth / 8;
204  }
205 
206  if (auto vecType = type.dyn_cast<VectorType>()) {
207  auto elementSize = getTypeNumBytes(options, vecType.getElementType());
208  if (!elementSize)
209  return llvm::None;
210  return vecType.getNumElements() * *elementSize;
211  }
212 
213  if (auto memRefType = type.dyn_cast<MemRefType>()) {
214  // TODO: Layout should also be controlled by the ABI attributes. For now
215  // using the layout from MemRef.
216  int64_t offset;
217  SmallVector<int64_t, 4> strides;
218  if (!memRefType.hasStaticShape() ||
219  failed(getStridesAndOffset(memRefType, strides, offset)))
220  return llvm::None;
221 
222  // To get the size of the memref object in memory, the total size is the
223  // max(stride * dimension-size) computed for all dimensions times the size
224  // of the element.
225  auto elementSize = getTypeNumBytes(options, memRefType.getElementType());
226  if (!elementSize)
227  return llvm::None;
228 
229  if (memRefType.getRank() == 0)
230  return elementSize;
231 
232  auto dims = memRefType.getShape();
233  if (llvm::is_contained(dims, ShapedType::kDynamicSize) ||
234  offset == MemRefType::getDynamicStrideOrOffset() ||
235  llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset()))
236  return llvm::None;
237 
238  int64_t memrefSize = -1;
239  for (const auto &shape : enumerate(dims))
240  memrefSize = std::max(memrefSize, shape.value() * strides[shape.index()]);
241 
242  return (offset + memrefSize) * *elementSize;
243  }
244 
245  if (auto tensorType = type.dyn_cast<TensorType>()) {
246  if (!tensorType.hasStaticShape())
247  return llvm::None;
248 
249  auto elementSize = getTypeNumBytes(options, tensorType.getElementType());
250  if (!elementSize)
251  return llvm::None;
252 
253  int64_t size = *elementSize;
254  for (auto shape : tensorType.getShape())
255  size *= shape;
256 
257  return size;
258  }
259 
260  // TODO: Add size computation for other types.
261  return llvm::None;
262 }
263 
264 /// Converts a scalar `type` to a suitable type under the given `targetEnv`.
265 static Type convertScalarType(const spirv::TargetEnv &targetEnv,
266  const SPIRVTypeConverter::Options &options,
267  spirv::ScalarType type,
268  Optional<spirv::StorageClass> storageClass = {}) {
269  // Get extension and capability requirements for the given type.
272  type.getExtensions(extensions, storageClass);
273  type.getCapabilities(capabilities, storageClass);
274 
275  // If all requirements are met, then we can accept this type as-is.
276  if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
277  succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
278  return type;
279 
280  // Otherwise we need to adjust the type, which really means adjusting the
281  // bitwidth given this is a scalar type.
282 
283  if (!options.emulateNon32BitScalarTypes)
284  return nullptr;
285 
286  if (auto floatType = type.dyn_cast<FloatType>()) {
287  LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
288  return Builder(targetEnv.getContext()).getF32Type();
289  }
290 
291  auto intType = type.cast<IntegerType>();
292  LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
293  return IntegerType::get(targetEnv.getContext(), /*width=*/32,
294  intType.getSignedness());
295 }
296 
297 /// Converts a vector `type` to a suitable type under the given `targetEnv`.
298 static Type convertVectorType(const spirv::TargetEnv &targetEnv,
299  const SPIRVTypeConverter::Options &options,
300  VectorType type,
301  Optional<spirv::StorageClass> storageClass = {}) {
302  if (type.getRank() == 1 && type.getNumElements() == 1)
303  return type.getElementType();
304 
305  if (!spirv::CompositeType::isValid(type)) {
306  // TODO: Vector types with more than four elements can be translated into
307  // array types.
308  LLVM_DEBUG(llvm::dbgs() << type << " illegal: > 4-element unimplemented\n");
309  return nullptr;
310  }
311 
312  // Get extension and capability requirements for the given type.
315  type.cast<spirv::CompositeType>().getExtensions(extensions, storageClass);
316  type.cast<spirv::CompositeType>().getCapabilities(capabilities, storageClass);
317 
318  // If all requirements are met, then we can accept this type as-is.
319  if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
320  succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
321  return type;
322 
323  auto elementType = convertScalarType(
324  targetEnv, options, type.getElementType().cast<spirv::ScalarType>(),
325  storageClass);
326  if (elementType)
327  return VectorType::get(type.getShape(), elementType);
328  return nullptr;
329 }
330 
331 /// Converts a tensor `type` to a suitable type under the given `targetEnv`.
332 ///
333 /// Note that this is mainly for lowering constant tensors. In SPIR-V one can
334 /// create composite constants with OpConstantComposite to embed relative large
335 /// constant values and use OpCompositeExtract and OpCompositeInsert to
336 /// manipulate, like what we do for vectors.
337 static Type convertTensorType(const spirv::TargetEnv &targetEnv,
338  const SPIRVTypeConverter::Options &options,
339  TensorType type) {
340  // TODO: Handle dynamic shapes.
341  if (!type.hasStaticShape()) {
342  LLVM_DEBUG(llvm::dbgs()
343  << type << " illegal: dynamic shape unimplemented\n");
344  return nullptr;
345  }
346 
347  auto scalarType = type.getElementType().dyn_cast<spirv::ScalarType>();
348  if (!scalarType) {
349  LLVM_DEBUG(llvm::dbgs()
350  << type << " illegal: cannot convert non-scalar element type\n");
351  return nullptr;
352  }
353 
354  Optional<int64_t> scalarSize = getTypeNumBytes(options, scalarType);
355  Optional<int64_t> tensorSize = getTypeNumBytes(options, type);
356  if (!scalarSize || !tensorSize) {
357  LLVM_DEBUG(llvm::dbgs()
358  << type << " illegal: cannot deduce element count\n");
359  return nullptr;
360  }
361 
362  auto arrayElemCount = *tensorSize / *scalarSize;
363  auto arrayElemType = convertScalarType(targetEnv, options, scalarType);
364  if (!arrayElemType)
365  return nullptr;
366  Optional<int64_t> arrayElemSize = getTypeNumBytes(options, arrayElemType);
367  if (!arrayElemSize) {
368  LLVM_DEBUG(llvm::dbgs()
369  << type << " illegal: cannot deduce converted element size\n");
370  return nullptr;
371  }
372 
373  return spirv::ArrayType::get(arrayElemType, arrayElemCount);
374 }
375 
377  const SPIRVTypeConverter::Options &options,
378  MemRefType type) {
379  Optional<spirv::StorageClass> storageClass =
381  type.getMemorySpaceAsInt());
382  if (!storageClass) {
383  LLVM_DEBUG(llvm::dbgs()
384  << type << " illegal: cannot convert memory space\n");
385  return nullptr;
386  }
387 
388  unsigned numBoolBits = options.boolNumBits;
389  if (numBoolBits != 8) {
390  LLVM_DEBUG(llvm::dbgs()
391  << "using non-8-bit storage for bool types unimplemented");
392  return nullptr;
393  }
394  auto elementType = IntegerType::get(type.getContext(), numBoolBits)
395  .dyn_cast<spirv::ScalarType>();
396  if (!elementType)
397  return nullptr;
398  Type arrayElemType =
399  convertScalarType(targetEnv, options, elementType, storageClass);
400  if (!arrayElemType)
401  return nullptr;
402  Optional<int64_t> arrayElemSize = getTypeNumBytes(options, arrayElemType);
403  if (!arrayElemSize) {
404  LLVM_DEBUG(llvm::dbgs()
405  << type << " illegal: cannot deduce converted element size\n");
406  return nullptr;
407  }
408 
409  if (!type.hasStaticShape()) {
410  int64_t stride = needsExplicitLayout(*storageClass) ? *arrayElemSize : 0;
411  auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
412  return wrapInStructAndGetPointer(arrayType, *storageClass);
413  }
414 
415  int64_t memrefSize = (type.getNumElements() * numBoolBits + 7) / 8;
416  auto arrayElemCount = llvm::divideCeil(memrefSize, *arrayElemSize);
417  int64_t stride = needsExplicitLayout(*storageClass) ? *arrayElemSize : 0;
418  auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
419 
420  return wrapInStructAndGetPointer(arrayType, *storageClass);
421 }
422 
423 static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
424  const SPIRVTypeConverter::Options &options,
425  MemRefType type) {
426  if (type.getElementType().isa<IntegerType>() &&
427  type.getElementTypeBitWidth() == 1) {
428  return convertBoolMemrefType(targetEnv, options, type);
429  }
430 
431  Optional<spirv::StorageClass> storageClass =
433  type.getMemorySpaceAsInt());
434  if (!storageClass) {
435  LLVM_DEBUG(llvm::dbgs()
436  << type << " illegal: cannot convert memory space\n");
437  return nullptr;
438  }
439 
440  Type arrayElemType;
441  Type elementType = type.getElementType();
442  if (auto vecType = elementType.dyn_cast<VectorType>()) {
443  arrayElemType =
444  convertVectorType(targetEnv, options, vecType, storageClass);
445  } else if (auto scalarType = elementType.dyn_cast<spirv::ScalarType>()) {
446  arrayElemType =
447  convertScalarType(targetEnv, options, scalarType, storageClass);
448  } else {
449  LLVM_DEBUG(
450  llvm::dbgs()
451  << type
452  << " unhandled: can only convert scalar or vector element type\n");
453  return nullptr;
454  }
455  if (!arrayElemType)
456  return nullptr;
457 
458  Optional<int64_t> arrayElemSize = getTypeNumBytes(options, arrayElemType);
459  if (!arrayElemSize) {
460  LLVM_DEBUG(llvm::dbgs()
461  << type << " illegal: cannot deduce converted element size\n");
462  return nullptr;
463  }
464 
465  if (!type.hasStaticShape()) {
466  int64_t stride = needsExplicitLayout(*storageClass) ? *arrayElemSize : 0;
467  auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
468  return wrapInStructAndGetPointer(arrayType, *storageClass);
469  }
470 
471  Optional<int64_t> memrefSize = getTypeNumBytes(options, type);
472  if (!memrefSize) {
473  LLVM_DEBUG(llvm::dbgs()
474  << type << " illegal: cannot deduce element count\n");
475  return nullptr;
476  }
477 
478  auto arrayElemCount = llvm::divideCeil(*memrefSize, *arrayElemSize);
479  int64_t stride = needsExplicitLayout(*storageClass) ? *arrayElemSize : 0;
480  auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
481 
482  return wrapInStructAndGetPointer(arrayType, *storageClass);
483 }
484 
486  Options options)
487  : targetEnv(targetAttr), options(options) {
488  // Add conversions. The order matters here: later ones will be tried earlier.
489 
490  // Allow all SPIR-V dialect specific types. This assumes all builtin types
491  // adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType)
492  // were tried before.
493  //
494  // TODO: this assumes that the SPIR-V types are valid to use in
495  // the given target environment, which should be the case if the whole
496  // pipeline is driven by the same target environment. Still, we probably still
497  // want to validate and convert to be safe.
498  addConversion([](spirv::SPIRVType type) { return type; });
499 
500  addConversion([this](IndexType /*indexType*/) { return getIndexType(); });
501 
502  addConversion([this](IntegerType intType) -> Optional<Type> {
503  if (auto scalarType = intType.dyn_cast<spirv::ScalarType>())
504  return convertScalarType(this->targetEnv, this->options, scalarType);
505  return Type();
506  });
507 
508  addConversion([this](FloatType floatType) -> Optional<Type> {
509  if (auto scalarType = floatType.dyn_cast<spirv::ScalarType>())
510  return convertScalarType(this->targetEnv, this->options, scalarType);
511  return Type();
512  });
513 
514  addConversion([this](VectorType vectorType) {
515  return convertVectorType(this->targetEnv, this->options, vectorType);
516  });
517 
518  addConversion([this](TensorType tensorType) {
519  return convertTensorType(this->targetEnv, this->options, tensorType);
520  });
521 
522  addConversion([this](MemRefType memRefType) {
523  return convertMemrefType(this->targetEnv, this->options, memRefType);
524  });
525 }
526 
527 //===----------------------------------------------------------------------===//
528 // func::FuncOp Conversion Patterns
529 //===----------------------------------------------------------------------===//
530 
531 namespace {
532 /// A pattern for rewriting function signature to convert arguments of functions
533 /// to be of valid SPIR-V types.
534 class FuncOpConversion final : public OpConversionPattern<func::FuncOp> {
535 public:
537 
539  matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
540  ConversionPatternRewriter &rewriter) const override;
541 };
542 } // namespace
543 
545 FuncOpConversion::matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
546  ConversionPatternRewriter &rewriter) const {
547  auto fnType = funcOp.getFunctionType();
548  if (fnType.getNumResults() > 1)
549  return failure();
550 
551  TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
552  for (const auto &argType : enumerate(fnType.getInputs())) {
553  auto convertedType = getTypeConverter()->convertType(argType.value());
554  if (!convertedType)
555  return failure();
556  signatureConverter.addInputs(argType.index(), convertedType);
557  }
558 
559  Type resultType;
560  if (fnType.getNumResults() == 1) {
561  resultType = getTypeConverter()->convertType(fnType.getResult(0));
562  if (!resultType)
563  return failure();
564  }
565 
566  // Create the converted spv.func op.
567  auto newFuncOp = rewriter.create<spirv::FuncOp>(
568  funcOp.getLoc(), funcOp.getName(),
569  rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
570  resultType ? TypeRange(resultType)
571  : TypeRange()));
572 
573  // Copy over all attributes other than the function name and type.
574  for (const auto &namedAttr : funcOp->getAttrs()) {
575  if (namedAttr.getName() != FunctionOpInterface::getTypeAttrName() &&
576  namedAttr.getName() != SymbolTable::getSymbolAttrName())
577  newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
578  }
579 
580  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
581  newFuncOp.end());
582  if (failed(rewriter.convertRegionTypes(
583  &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter)))
584  return failure();
585  rewriter.eraseOp(funcOp);
586  return success();
587 }
588 
590  RewritePatternSet &patterns) {
591  patterns.add<FuncOpConversion>(typeConverter, patterns.getContext());
592 }
593 
594 //===----------------------------------------------------------------------===//
595 // Builtin Variables
596 //===----------------------------------------------------------------------===//
597 
598 static spirv::GlobalVariableOp getBuiltinVariable(Block &body,
599  spirv::BuiltIn builtin) {
600  // Look through all global variables in the given `body` block and check if
601  // there is a spv.GlobalVariable that has the same `builtin` attribute.
602  for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
603  if (auto builtinAttr = varOp->getAttrOfType<StringAttr>(
604  spirv::SPIRVDialect::getAttributeName(
605  spirv::Decoration::BuiltIn))) {
606  auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
607  if (varBuiltIn && *varBuiltIn == builtin) {
608  return varOp;
609  }
610  }
611  }
612  return nullptr;
613 }
614 
615 /// Gets name of global variable for a builtin.
616 static std::string getBuiltinVarName(spirv::BuiltIn builtin) {
617  return std::string("__builtin_var_") + stringifyBuiltIn(builtin).str() + "__";
618 }
619 
620 /// Gets or inserts a global variable for a builtin within `body` block.
621 static spirv::GlobalVariableOp
622 getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
623  Type integerType, OpBuilder &builder) {
624  if (auto varOp = getBuiltinVariable(body, builtin))
625  return varOp;
626 
627  OpBuilder::InsertionGuard guard(builder);
628  builder.setInsertionPointToStart(&body);
629 
630  spirv::GlobalVariableOp newVarOp;
631  switch (builtin) {
632  case spirv::BuiltIn::NumWorkgroups:
633  case spirv::BuiltIn::WorkgroupSize:
634  case spirv::BuiltIn::WorkgroupId:
635  case spirv::BuiltIn::LocalInvocationId:
636  case spirv::BuiltIn::GlobalInvocationId: {
637  auto ptrType = spirv::PointerType::get(VectorType::get({3}, integerType),
638  spirv::StorageClass::Input);
639  std::string name = getBuiltinVarName(builtin);
640  newVarOp =
641  builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
642  break;
643  }
644  case spirv::BuiltIn::SubgroupId:
645  case spirv::BuiltIn::NumSubgroups:
646  case spirv::BuiltIn::SubgroupSize: {
647  auto ptrType =
648  spirv::PointerType::get(integerType, spirv::StorageClass::Input);
649  std::string name = getBuiltinVarName(builtin);
650  newVarOp =
651  builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
652  break;
653  }
654  default:
655  emitError(loc, "unimplemented builtin variable generation for ")
656  << stringifyBuiltIn(builtin);
657  }
658  return newVarOp;
659 }
660 
662  spirv::BuiltIn builtin,
663  Type integerType,
664  OpBuilder &builder) {
666  if (!parent) {
667  op->emitError("expected operation to be within a module-like op");
668  return nullptr;
669  }
670 
671  spirv::GlobalVariableOp varOp =
672  getOrInsertBuiltinVariable(*parent->getRegion(0).begin(), op->getLoc(),
673  builtin, integerType, builder);
674  Value ptr = builder.create<spirv::AddressOfOp>(op->getLoc(), varOp);
675  return builder.create<spirv::LoadOp>(op->getLoc(), ptr);
676 }
677 
678 //===----------------------------------------------------------------------===//
679 // Push constant storage
680 //===----------------------------------------------------------------------===//
681 
682 /// Returns the pointer type for the push constant storage containing
683 /// `elementCount` 32-bit integer values.
684 static spirv::PointerType getPushConstantStorageType(unsigned elementCount,
685  Builder &builder,
686  Type indexType) {
687  auto arrayType = spirv::ArrayType::get(indexType, elementCount,
688  /*stride=*/4);
689  auto structType = spirv::StructType::get({arrayType}, /*offsetInfo=*/0);
690  return spirv::PointerType::get(structType, spirv::StorageClass::PushConstant);
691 }
692 
693 /// Returns the push constant variable containing `elementCount` integer
694 /// values in `body`. Returns a null op if such an op does not exist.
695 static spirv::GlobalVariableOp getPushConstantVariable(Block &body,
696  unsigned elementCount) {
697  for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
698  auto ptrType = varOp.type().dyn_cast<spirv::PointerType>();
699  if (!ptrType)
700  continue;
701 
702  // Note that Vulkan requires "There must be no more than one push constant
703  // block statically used per shader entry point." So we should always reuse
704  // the existing one.
705  if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) {
706  auto numElements = ptrType.getPointeeType()
708  .getElementType(0)
710  .getNumElements();
711  if (numElements == elementCount)
712  return varOp;
713  }
714  }
715  return nullptr;
716 }
717 
718 /// Gets or inserts a global variable for push constant storage containing
719 /// `elementCount` 32-bit integer values in `block`.
720 static spirv::GlobalVariableOp
722  unsigned elementCount, OpBuilder &b,
723  Type indexType) {
724  if (auto varOp = getPushConstantVariable(block, elementCount))
725  return varOp;
726 
727  auto builder = OpBuilder::atBlockBegin(&block, b.getListener());
728  auto type = getPushConstantStorageType(elementCount, builder, indexType);
729  const char *name = "__push_constant_var__";
730  return builder.create<spirv::GlobalVariableOp>(loc, type, name,
731  /*initializer=*/nullptr);
732 }
733 
734 Value spirv::getPushConstantValue(Operation *op, unsigned elementCount,
735  unsigned offset, Type integerType,
736  OpBuilder &builder) {
737  Location loc = op->getLoc();
739  if (!parent) {
740  op->emitError("expected operation to be within a module-like op");
741  return nullptr;
742  }
743 
744  spirv::GlobalVariableOp varOp = getOrInsertPushConstantVariable(
745  loc, parent->getRegion(0).front(), elementCount, builder, integerType);
746 
747  Value zeroOp = spirv::ConstantOp::getZero(integerType, loc, builder);
748  Value offsetOp = builder.create<spirv::ConstantOp>(
749  loc, integerType, builder.getI32IntegerAttr(offset));
750  auto addrOp = builder.create<spirv::AddressOfOp>(loc, varOp);
751  auto acOp = builder.create<spirv::AccessChainOp>(
752  loc, addrOp, llvm::makeArrayRef({zeroOp, offsetOp}));
753  return builder.create<spirv::LoadOp>(loc, acOp);
754 }
755 
756 //===----------------------------------------------------------------------===//
757 // Index calculation
758 //===----------------------------------------------------------------------===//
759 
761  int64_t offset, Type integerType,
762  Location loc, OpBuilder &builder) {
763  assert(indices.size() == strides.size() &&
764  "must provide indices for all dimensions");
765 
766  // TODO: Consider moving to use affine.apply and patterns converting
767  // affine.apply to standard ops. This needs converting to SPIR-V passes to be
768  // broken down into progressive small steps so we can have intermediate steps
769  // using other dialects. At the moment SPIR-V is the final sink.
770 
771  Value linearizedIndex = builder.create<spirv::ConstantOp>(
772  loc, integerType, IntegerAttr::get(integerType, offset));
773  for (const auto &index : llvm::enumerate(indices)) {
774  Value strideVal = builder.create<spirv::ConstantOp>(
775  loc, integerType,
776  IntegerAttr::get(integerType, strides[index.index()]));
777  Value update = builder.create<spirv::IMulOp>(loc, strideVal, index.value());
778  linearizedIndex =
779  builder.create<spirv::IAddOp>(loc, linearizedIndex, update);
780  }
781  return linearizedIndex;
782 }
783 
784 spirv::AccessChainOp mlir::spirv::getElementPtr(
785  SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr,
786  ValueRange indices, Location loc, OpBuilder &builder) {
787  // Get base and offset of the MemRefType and verify they are static.
788 
789  int64_t offset;
790  SmallVector<int64_t, 4> strides;
791  if (failed(getStridesAndOffset(baseType, strides, offset)) ||
792  llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset()) ||
793  offset == MemRefType::getDynamicStrideOrOffset()) {
794  return nullptr;
795  }
796 
797  auto indexType = typeConverter.getIndexType();
798 
799  SmallVector<Value, 2> linearizedIndices;
800  auto zero = spirv::ConstantOp::getZero(indexType, loc, builder);
801 
802  // Add a '0' at the start to index into the struct.
803  linearizedIndices.push_back(zero);
804 
805  if (baseType.getRank() == 0) {
806  linearizedIndices.push_back(zero);
807  } else {
808  linearizedIndices.push_back(
809  linearizeIndex(indices, strides, offset, indexType, loc, builder));
810  }
811  return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
812 }
813 
814 //===----------------------------------------------------------------------===//
815 // SPIR-V ConversionTarget
816 //===----------------------------------------------------------------------===//
817 
818 std::unique_ptr<SPIRVConversionTarget>
820  std::unique_ptr<SPIRVConversionTarget> target(
821  // std::make_unique does not work here because the constructor is private.
822  new SPIRVConversionTarget(targetAttr));
823  SPIRVConversionTarget *targetPtr = target.get();
824  target->addDynamicallyLegalDialect<spirv::SPIRVDialect>(
825  // We need to capture the raw pointer here because it is stable:
826  // target will be destroyed once this function is returned.
827  [targetPtr](Operation *op) { return targetPtr->isLegalOp(op); });
828  return target;
829 }
830 
831 SPIRVConversionTarget::SPIRVConversionTarget(spirv::TargetEnvAttr targetAttr)
832  : ConversionTarget(*targetAttr.getContext()), targetEnv(targetAttr) {}
833 
834 bool SPIRVConversionTarget::isLegalOp(Operation *op) {
835  // Make sure this op is available at the given version. Ops not implementing
836  // QueryMinVersionInterface/QueryMaxVersionInterface are available to all
837  // SPIR-V versions.
838  if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
839  Optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
840  if (minVersion && *minVersion > this->targetEnv.getVersion()) {
841  LLVM_DEBUG(llvm::dbgs()
842  << op->getName() << " illegal: requiring min version "
843  << spirv::stringifyVersion(*minVersion) << "\n");
844  return false;
845  }
846  }
847  if (auto maxVersionIfx = dyn_cast<spirv::QueryMaxVersionInterface>(op)) {
848  Optional<spirv::Version> maxVersion = maxVersionIfx.getMaxVersion();
849  if (maxVersion && *maxVersion < this->targetEnv.getVersion()) {
850  LLVM_DEBUG(llvm::dbgs()
851  << op->getName() << " illegal: requiring max version "
852  << spirv::stringifyVersion(*maxVersion) << "\n");
853  return false;
854  }
855  }
856 
857  // Make sure this op's required extensions are allowed to use. Ops not
858  // implementing QueryExtensionInterface do not require extensions to be
859  // available.
860  if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
861  if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
862  extensions.getExtensions())))
863  return false;
864 
865  // Make sure this op's required extensions are allowed to use. Ops not
866  // implementing QueryCapabilityInterface do not require capabilities to be
867  // available.
868  if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
869  if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
870  capabilities.getCapabilities())))
871  return false;
872 
873  SmallVector<Type, 4> valueTypes;
874  valueTypes.append(op->operand_type_begin(), op->operand_type_end());
875  valueTypes.append(op->result_type_begin(), op->result_type_end());
876 
877  // Ensure that all types have been converted to SPIRV types.
878  if (llvm::any_of(valueTypes,
879  [](Type t) { return !t.isa<spirv::SPIRVType>(); }))
880  return false;
881 
882  // Special treatment for global variables, whose type requirements are
883  // conveyed by type attributes.
884  if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
885  valueTypes.push_back(globalVar.type());
886 
887  // Make sure the op's operands/results use types that are allowed by the
888  // target environment.
889  SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
890  SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
891  for (Type valueType : valueTypes) {
892  typeExtensions.clear();
893  valueType.cast<spirv::SPIRVType>().getExtensions(typeExtensions);
894  if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
895  typeExtensions)))
896  return false;
897 
898  typeCapabilities.clear();
899  valueType.cast<spirv::SPIRVType>().getCapabilities(typeCapabilities);
900  if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
901  typeCapabilities)))
902  return false;
903  }
904 
905  return true;
906 }
TODO: Remove this file when SCCP and integer range analysis have been ported to the new framework...
Type getIndexType() const
Gets the SPIR-V correspondence for the standard index type.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:55
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
const Options & getOptions() const
Returns the options controlling the SPIR-V type converter.
Type getPointeeType() const
Definition: SPIRVTypes.cpp:395
Block represents an ordered list of Operations.
Definition: Block.h:29
bool use64bitIndex
Use 64-bit integers to convert index types.
Block & front()
Definition: Region.h:65
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, Optional< StorageClass > storage=llvm::None)
Definition: SPIRVTypes.cpp:506
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity...
Definition: SPIRVOps.cpp:688
static Type convertMemrefType(const spirv::TargetEnv &targetEnv, const SPIRVTypeConverter::Options &options, MemRefType type)
static unsigned getMemorySpaceForStorageClass(spirv::StorageClass)
Returns the corresponding memory space for memref given a SPIR-V storage class.
TargetEnvAttr getAttr() const
Definition: TargetAndABI.h:61
static Type convertTensorType(const spirv::TargetEnv &targetEnv, const SPIRVTypeConverter::Options &options, TensorType type)
Converts a tensor type to a suitable type under the given targetEnv.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value...
Definition: LogicalResult.h:68
static RuntimeArrayType get(Type elementType)
Definition: SPIRVTypes.cpp:448
SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr, Options options={})
StringRef getTypeAttrName()
Return the name of the attribute used for function types.
static OpBuilder atBlockBegin(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still ins...
Definition: Builders.h:216
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
static spirv::GlobalVariableOp getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin, Type integerType, OpBuilder &builder)
Gets or inserts a global variable for a builtin within body block.
bool allows(Capability) const
Returns true if the given capability is allowed.
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
Definition: SPIRVTypes.cpp:952
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:380
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
static Type convertVectorType(const spirv::TargetEnv &targetEnv, const SPIRVTypeConverter::Options &options, VectorType type, Optional< spirv::StorageClass > storageClass={})
Converts a vector type to a suitable type under the given targetEnv.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
This class provides all of the information necessary to convert a type signature. ...
iterator begin()
Definition: Region.h:55
#define STORAGE_SPACE_MAP_LIST(MAP_FN)
Mapping between SPIR-V storage classes to memref memory spaces.
Value getBuiltinVariableValue(Operation *op, BuiltIn builtin, Type integerType, OpBuilder &builder)
Returns the value for the given builtin variable.
static ArrayType get(Type elementType, unsigned elementCount)
Definition: SPIRVTypes.cpp:48
static spirv::GlobalVariableOp getPushConstantVariable(Block &body, unsigned elementCount)
Returns the push constant variable containing elementCount 32-bit integer values in body...
MLIRContext * getContext() const
Returns the MLIRContext.
U dyn_cast() const
Definition: Types.h:256
void populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating the builtin func op to the SPIR-V diale...
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:172
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:234
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:38
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:161
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener...
Definition: Builders.h:258
result_type_iterator result_type_end()
Definition: Operation.h:351
operand_type_iterator operand_type_begin()
Definition: Operation.h:319
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:77
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before) override
PatternRewriter hook for moving blocks out of a region.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
Definition: TargetAndABI.h:28
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:391
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
operand_type_iterator operand_type_end()
Definition: Operation.h:320
static std::string getBuiltinVarName(spirv::BuiltIn builtin)
Gets name of global variable for a builtin.
static spirv::GlobalVariableOp getOrInsertPushConstantVariable(Location loc, Block &block, unsigned elementCount, OpBuilder &b, Type indexType)
Gets or inserts a global variable for push constant storage containing elementCount 32-bit integer va...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:369
static Optional< int64_t > getTypeNumBytes(const SPIRVTypeConverter::Options &options, Type type)
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:286
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
void addConversion(FnT &&callback)
Register a conversion function.
Value getPushConstantValue(Operation *op, unsigned elementCount, unsigned offset, Type integerType, OpBuilder &builder)
Gets the value at the given offset of the push constant storage with a total of elementCount integerT...
This class is a general helper class for creating context-global objects like types, attributes, and affine expressions.
Definition: Builders.h:49
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&... args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
result_type_iterator result_type_begin()
Definition: Operation.h:350
static spirv::PointerType getPushConstantStorageType(unsigned elementCount, Builder &builder, Type indexType)
Returns the pointer type for the push constant storage containing elementCount 32-bit integer values...
static spirv::GlobalVariableOp getBuiltinVariable(Block &body, spirv::BuiltIn builtin)
static int64_t getNumElements(ShapedType type)
Definition: TensorOps.cpp:753
Value linearizeIndex(ValueRange indices, ArrayRef< int64_t > strides, int64_t offset, Type integerType, Location loc, OpBuilder &builder)
Generates IR to perform index linearization with the given indices and their corresponding strides...
spirv::AccessChainOp getElementPtr(SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr...
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, Optional< StorageClass > storage=llvm::None)
Definition: SPIRVTypes.cpp:537
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'. ...
Definition: Block.h:184
static VectorType vectorType(CodeGen &codegen, Type etp)
Constructs vector type.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
Type getElementType() const
Returns the element type of this tensor type.
SPIR-V struct type.
Definition: SPIRVTypes.h:278
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:91
This class implements a pattern rewriter for use with ConversionPatterns.
static spirv::PointerType wrapInStructAndGetPointer(Type elementType, spirv::StorageClass storageClass)
Wraps the given elementType in a struct and gets the pointer to the struct.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
#define STORAGE_SPACE_MAP_FN(storage, space)
unsigned boolNumBits
The number of bits to store a boolean value.
This class describes a specific conversion target.
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:67
static Optional< spirv::StorageClass > getStorageClassForMemorySpace(unsigned space)
Returns the SPIR-V storage class given a memory space for memref.
static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv, const SPIRVTypeConverter::Options &options, MemRefType type)
static Type convertScalarType(const spirv::TargetEnv &targetEnv, const SPIRVTypeConverter::Options &options, spirv::ScalarType type, Optional< spirv::StorageClass > storageClass={})
Converts a scalar type to a suitable type under the given targetEnv.
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:57
bool isa() const
Definition: Types.h:246
Version getVersion() const
static bool needsExplicitLayout(spirv::StorageClass storageClass)
Returns true if the given storageClass needs explicit layout when used in Shader environments.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:231
This class helps build Operations.
Definition: Builders.h:184
This class provides an abstraction over the different types of ranges over Values.
static LogicalResult checkCapabilityRequirements(LabelT label, const spirv::TargetEnv &targetEnv, const spirv::SPIRVType::CapabilityArrayRefVector &candidates)
Checks that the candidates' capability requirements can be satisfied with the given targetEnv...
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:484
MLIRContext * getContext() const
bool emulateNon32BitScalarTypes
Whether to emulate non-32-bit scalar types with 32-bit scalar types if no native support.
An attribute that specifies the target version, allowed extensions and capabilities, and resource limits.
Type conversion from builtin types to SPIR-V types for shader interface.
FailureOr< Block * > convertRegionTypes(Region *region, TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Convert the types of block arguments within the given region.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static LogicalResult checkExtensionRequirements(LabelT label, const spirv::TargetEnv &targetEnv, const spirv::SPIRVType::ExtensionArrayRefVector &candidates)
Checks that the candidates' extension requirements can be satisfied with the given targetEnv...
U cast() const
Definition: Types.h:262
static bool isValid(VectorType)
Returns true if the given vector type is valid for the SPIR-V dialect.
Definition: SPIRVTypes.cpp:97