22 #include "llvm/ADT/StringExtras.h"
23 #include "llvm/Support/Debug.h"
24 #include "llvm/Support/MathExtras.h"
29 #define DEBUG_TYPE "mlir-spirv-conversion"
43 template <
typename LabelT>
47 for (
const auto &ors : candidates) {
53 for (spirv::Extension ext : ors)
54 extStrings.push_back(spirv::stringifyExtension(ext));
56 llvm::dbgs() << label <<
" illegal: requires at least one extension in ["
57 << llvm::join(extStrings,
", ")
58 <<
"] but none allowed in target environment\n";
71 template <
typename LabelT>
75 for (
const auto &ors : candidates) {
81 for (spirv::Capability cap : ors)
82 capStrings.push_back(spirv::stringifyCapability(cap));
84 llvm::dbgs() << label <<
" illegal: requires at least one capability in ["
85 << llvm::join(capStrings,
", ")
86 <<
"] but none allowed in target environment\n";
96 switch (storageClass) {
97 case spirv::StorageClass::PhysicalStorageBuffer:
98 case spirv::StorageClass::PushConstant:
99 case spirv::StorageClass::StorageBuffer:
100 case spirv::StorageClass::Uniform:
123 return cast<spirv::ScalarType>(
131 MLIRContext *SPIRVTypeConverter::getContext()
const {
132 return targetEnv.
getAttr().getContext();
136 return targetEnv.
allows(capability);
141 static std::optional<int64_t>
143 if (isa<spirv::ScalarType>(type)) {
156 if (
auto complexType = dyn_cast<ComplexType>(type)) {
160 return 2 * *elementSize;
163 if (
auto vecType = dyn_cast<VectorType>(type)) {
167 return vecType.getNumElements() * *elementSize;
170 if (
auto memRefType = dyn_cast<MemRefType>(type)) {
175 if (!memRefType.hasStaticShape() ||
186 if (memRefType.getRank() == 0)
189 auto dims = memRefType.getShape();
190 if (llvm::is_contained(dims, ShapedType::kDynamic) ||
191 ShapedType::isDynamic(offset) ||
192 llvm::is_contained(strides, ShapedType::kDynamic))
195 int64_t memrefSize = -1;
196 for (
const auto &shape :
enumerate(dims))
197 memrefSize =
std::max(memrefSize, shape.value() * strides[shape.index()]);
199 return (offset + memrefSize) * *elementSize;
202 if (
auto tensorType = dyn_cast<TensorType>(type)) {
203 if (!tensorType.hasStaticShape())
210 int64_t size = *elementSize;
211 for (
auto shape : tensorType.getShape())
225 std::optional<spirv::StorageClass> storageClass = {}) {
239 if (!
options.emulateLT32BitScalarTypes)
244 LLVM_DEBUG(llvm::dbgs()
246 <<
" not converted to 32-bit for SPIR-V to avoid truncation\n");
250 if (
auto floatType = dyn_cast<FloatType>(type)) {
251 LLVM_DEBUG(llvm::dbgs() << type <<
" converted to 32-bit for SPIR-V\n");
255 auto intType = cast<IntegerType>(type);
256 LLVM_DEBUG(llvm::dbgs() << type <<
" converted to 32-bit for SPIR-V\n");
258 intType.getSignedness());
270 LLVM_DEBUG(llvm::dbgs() <<
"unsupported sub-byte storage kind\n");
274 if (!llvm::isPowerOf2_32(type.getWidth())) {
275 LLVM_DEBUG(llvm::dbgs()
276 <<
"unsupported non-power-of-two bitwidth in sub-byte" << type
281 LLVM_DEBUG(llvm::dbgs() << type <<
" converted to 32-bit for SPIR-V\n");
283 type.getSignedness());
292 Type indexType = dyn_cast<IndexType>(type.getElementType());
303 std::optional<spirv::StorageClass> storageClass = {}) {
305 auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
309 auto intType = dyn_cast<IntegerType>(type.getElementType());
311 LLVM_DEBUG(llvm::dbgs()
313 <<
" illegal: cannot convert non-scalar element type\n");
318 if (type.getRank() <= 1 && type.getNumElements() == 1)
321 if (type.getNumElements() > 4) {
322 LLVM_DEBUG(llvm::dbgs()
323 << type <<
" illegal: > 4-element unimplemented\n");
330 if (type.getRank() <= 1 && type.getNumElements() == 1)
334 LLVM_DEBUG(llvm::dbgs()
335 << type <<
" illegal: not a valid composite type\n");
342 cast<spirv::CompositeType>(type).getExtensions(extensions, storageClass);
343 cast<spirv::CompositeType>(type).getCapabilities(capabilities, storageClass);
360 std::optional<spirv::StorageClass> storageClass = {}) {
361 auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
363 LLVM_DEBUG(llvm::dbgs()
364 << type <<
" illegal: cannot convert non-scalar element type\n");
372 if (elementType != type.getElementType()) {
373 LLVM_DEBUG(llvm::dbgs()
374 << type <<
" illegal: complex type emulation unsupported\n");
391 if (!type.hasStaticShape()) {
392 LLVM_DEBUG(llvm::dbgs()
393 << type <<
" illegal: dynamic shape unimplemented\n");
398 auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.
getElementType());
400 LLVM_DEBUG(llvm::dbgs()
401 << type <<
" illegal: cannot convert non-scalar element type\n");
407 if (!scalarSize || !tensorSize) {
408 LLVM_DEBUG(llvm::dbgs()
409 << type <<
" illegal: cannot deduce element count\n");
413 int64_t arrayElemCount = *tensorSize / *scalarSize;
414 if (arrayElemCount == 0) {
415 LLVM_DEBUG(llvm::dbgs()
416 << type <<
" illegal: cannot handle zero-element tensors\n");
423 std::optional<int64_t> arrayElemSize =
425 if (!arrayElemSize) {
426 LLVM_DEBUG(llvm::dbgs()
427 << type <<
" illegal: cannot deduce converted element size\n");
437 spirv::StorageClass storageClass) {
438 unsigned numBoolBits =
options.boolNumBits;
439 if (numBoolBits != 8) {
440 LLVM_DEBUG(llvm::dbgs()
441 <<
"using non-8-bit storage for bool types unimplemented");
444 auto elementType = dyn_cast<spirv::ScalarType>(
452 std::optional<int64_t> arrayElemSize =
454 if (!arrayElemSize) {
455 LLVM_DEBUG(llvm::dbgs()
456 << type <<
" illegal: cannot deduce converted element size\n");
460 if (!type.hasStaticShape()) {
463 if (targetEnv.
allows(spirv::Capability::Kernel))
472 if (type.getNumElements() == 0) {
473 LLVM_DEBUG(llvm::dbgs()
474 << type <<
" illegal: zero-element memrefs are not supported\n");
478 int64_t memrefSize =
llvm::divideCeil(type.getNumElements() * numBoolBits, 8);
482 if (targetEnv.
allows(spirv::Capability::Kernel))
490 spirv::StorageClass storageClass) {
491 IntegerType elementType = cast<IntegerType>(type.getElementType());
497 if (!type.hasStaticShape()) {
500 if (targetEnv.
allows(spirv::Capability::Kernel))
509 if (type.getNumElements() == 0) {
510 LLVM_DEBUG(llvm::dbgs()
511 << type <<
" illegal: zero-element memrefs are not supported\n");
520 if (targetEnv.
allows(spirv::Capability::Kernel))
528 auto attr = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
533 <<
" illegal: expected memory space to be a SPIR-V storage class "
534 "attribute; please use MemorySpaceToStorageClassConverter to map "
535 "numeric memory spaces beforehand\n");
538 spirv::StorageClass storageClass = attr.getValue();
540 if (isa<IntegerType>(type.getElementType())) {
541 if (type.getElementTypeBitWidth() == 1)
543 if (type.getElementTypeBitWidth() < 8)
548 Type elementType = type.getElementType();
549 if (
auto vecType = dyn_cast<VectorType>(elementType)) {
552 }
else if (
auto complexType = dyn_cast<ComplexType>(elementType)) {
555 }
else if (
auto scalarType = dyn_cast<spirv::ScalarType>(elementType)) {
558 }
else if (
auto indexType = dyn_cast<IndexType>(elementType)) {
560 arrayElemType = type.getElementType();
565 <<
" unhandled: can only convert scalar or vector element type\n");
571 std::optional<int64_t> arrayElemSize =
573 if (!arrayElemSize) {
574 LLVM_DEBUG(llvm::dbgs()
575 << type <<
" illegal: cannot deduce converted element size\n");
579 if (!type.hasStaticShape()) {
582 if (targetEnv.
allows(spirv::Capability::Kernel))
593 LLVM_DEBUG(llvm::dbgs()
594 << type <<
" illegal: cannot deduce element count\n");
598 if (*memrefSize == 0) {
599 LLVM_DEBUG(llvm::dbgs()
600 << type <<
" illegal: zero-element memrefs are not supported\n");
607 if (targetEnv.
allows(spirv::Capability::Kernel))
633 if (inputs.size() != 1) {
634 auto castOp = builder.
create<UnrealizedConversionCastOp>(loc, type, inputs);
637 Value input = inputs.front();
640 if (!isa<IntegerType>(type)) {
641 auto castOp = builder.
create<UnrealizedConversionCastOp>(loc, type, inputs);
644 auto inputType = cast<IntegerType>(input.
getType());
646 auto scalarType = dyn_cast<spirv::ScalarType>(type);
648 auto castOp = builder.
create<UnrealizedConversionCastOp>(loc, type, inputs);
655 if (inputType.getIntOrFloatBitWidth() < scalarType.getIntOrFloatBitWidth()) {
656 auto castOp = builder.
create<UnrealizedConversionCastOp>(loc, type, inputs);
662 Value one = spirv::ConstantOp::getOne(inputType, loc, builder);
663 return builder.
create<spirv::IEqualOp>(loc, input, one);
669 scalarType.getExtensions(exts);
670 scalarType.getCapabilities(caps);
673 auto castOp = builder.
create<UnrealizedConversionCastOp>(loc, type, inputs);
681 return builder.
create<spirv::SConvertOp>(loc, type, input);
683 return builder.
create<spirv::UConvertOp>(loc, type, input);
707 addConversion([
this](IntegerType intType) -> std::optional<Type> {
708 if (
auto scalarType = dyn_cast<spirv::ScalarType>(intType))
710 if (intType.getWidth() < 8)
716 if (
auto scalarType = dyn_cast<spirv::ScalarType>(floatType))
744 auto cast = builder.
create<UnrealizedConversionCastOp>(loc, type, inputs);
745 return std::optional<Value>(cast.getResult(0));
761 matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
767 FuncOpConversion::matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
769 auto fnType = funcOp.getFunctionType();
770 if (fnType.getNumResults() > 1)
774 for (
const auto &argType :
enumerate(fnType.getInputs())) {
775 auto convertedType = getTypeConverter()->convertType(argType.value());
778 signatureConverter.addInputs(argType.index(), convertedType);
782 if (fnType.getNumResults() == 1) {
783 resultType = getTypeConverter()->convertType(fnType.getResult(0));
789 auto newFuncOp = rewriter.
create<spirv::FuncOp>(
790 funcOp.getLoc(), funcOp.getName(),
796 for (
const auto &namedAttr : funcOp->getAttrs()) {
797 if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() &&
799 newFuncOp->
setAttr(namedAttr.getName(), namedAttr.getValue());
805 &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter)))
813 patterns.
add<FuncOpConversion>(typeConverter, patterns.
getContext());
821 spirv::BuiltIn builtin) {
824 for (
auto varOp : body.
getOps<spirv::GlobalVariableOp>()) {
825 if (
auto builtinAttr = varOp->getAttrOfType<StringAttr>(
826 spirv::SPIRVDialect::getAttributeName(
827 spirv::Decoration::BuiltIn))) {
828 auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
829 if (varBuiltIn && *varBuiltIn == builtin) {
840 return Twine(prefix).concat(stringifyBuiltIn(builtin)).concat(suffix).str();
844 static spirv::GlobalVariableOp
847 StringRef prefix, StringRef suffix) {
854 spirv::GlobalVariableOp newVarOp;
856 case spirv::BuiltIn::NumWorkgroups:
857 case spirv::BuiltIn::WorkgroupSize:
858 case spirv::BuiltIn::WorkgroupId:
859 case spirv::BuiltIn::LocalInvocationId:
860 case spirv::BuiltIn::GlobalInvocationId: {
862 spirv::StorageClass::Input);
865 builder.
create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
868 case spirv::BuiltIn::SubgroupId:
869 case spirv::BuiltIn::NumSubgroups:
870 case spirv::BuiltIn::SubgroupSize: {
875 builder.
create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
879 emitError(loc,
"unimplemented builtin variable generation for ")
880 << stringifyBuiltIn(builtin);
886 spirv::BuiltIn builtin,
888 StringRef prefix, StringRef suffix) {
891 op->
emitError(
"expected operation to be within a module-like op");
895 spirv::GlobalVariableOp varOp =
897 builtin, integerType, builder, prefix, suffix);
899 return builder.
create<spirv::LoadOp>(op->
getLoc(), ptr);
920 unsigned elementCount) {
921 for (
auto varOp : body.
getOps<spirv::GlobalVariableOp>()) {
922 auto ptrType = dyn_cast<spirv::PointerType>(varOp.getType());
929 if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) {
930 auto numElements = cast<spirv::ArrayType>(
931 cast<spirv::StructType>(ptrType.getPointeeType())
934 if (numElements == elementCount)
943 static spirv::GlobalVariableOp
952 const char *name =
"__push_constant_var__";
953 return builder.
create<spirv::GlobalVariableOp>(loc, type, name,
958 unsigned offset,
Type integerType,
963 op->
emitError(
"expected operation to be within a module-like op");
968 loc, parent->
getRegion(0).
front(), elementCount, builder, integerType);
971 Value offsetOp = builder.
create<spirv::ConstantOp>(
973 auto addrOp = builder.
create<spirv::AddressOfOp>(loc, varOp);
974 auto acOp = builder.
create<spirv::AccessChainOp>(
976 return builder.
create<spirv::LoadOp>(loc, acOp);
984 int64_t offset,
Type integerType,
986 assert(indices.size() == strides.size() &&
987 "must provide indices for all dimensions");
994 Value linearizedIndex = builder.
create<spirv::ConstantOp>(
997 Value strideVal = builder.
create<spirv::ConstantOp>(
1000 Value update = builder.
create<spirv::IMulOp>(loc, strideVal, index.value());
1002 builder.
create<spirv::IAddOp>(loc, linearizedIndex, update);
1004 return linearizedIndex;
1008 MemRefType baseType,
Value basePtr,
1016 llvm::is_contained(strides, ShapedType::kDynamic) ||
1017 ShapedType::isDynamic(offset)) {
1027 linearizedIndices.push_back(zero);
1029 if (baseType.getRank() == 0) {
1030 linearizedIndices.push_back(zero);
1032 linearizedIndices.push_back(
1033 linearizeIndex(indices, strides, offset, indexType, loc, builder));
1035 return builder.
create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
1039 MemRefType baseType,
Value basePtr,
1047 llvm::is_contained(strides, ShapedType::kDynamic) ||
1048 ShapedType::isDynamic(offset)) {
1056 if (baseType.getRank() == 0) {
1060 linearizeIndex(indices, strides, offset, indexType, loc, builder);
1063 cast<spirv::PointerType>(basePtr.
getType()).getPointeeType();
1064 if (isa<spirv::ArrayType>(pointeeType)) {
1065 linearizedIndices.push_back(linearIndex);
1066 return builder.
create<spirv::AccessChainOp>(loc, basePtr,
1069 return builder.
create<spirv::PtrAccessChainOp>(loc, basePtr, linearIndex,
1074 MemRefType baseType,
Value basePtr,
1078 if (typeConverter.
allows(spirv::Capability::Kernel)) {
1091 std::unique_ptr<SPIRVConversionTarget>
1093 std::unique_ptr<SPIRVConversionTarget> target(
1097 target->addDynamicallyLegalDialect<spirv::SPIRVDialect>(
1100 [targetPtr](
Operation *op) {
return targetPtr->isLegalOp(op); });
1107 bool SPIRVConversionTarget::isLegalOp(
Operation *op) {
1111 if (
auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
1112 std::optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
1113 if (minVersion && *minVersion > this->targetEnv.
getVersion()) {
1114 LLVM_DEBUG(llvm::dbgs()
1115 << op->
getName() <<
" illegal: requiring min version "
1116 << spirv::stringifyVersion(*minVersion) <<
"\n");
1120 if (
auto maxVersionIfx = dyn_cast<spirv::QueryMaxVersionInterface>(op)) {
1121 std::optional<spirv::Version> maxVersion = maxVersionIfx.getMaxVersion();
1122 if (maxVersion && *maxVersion < this->targetEnv.
getVersion()) {
1123 LLVM_DEBUG(llvm::dbgs()
1124 << op->
getName() <<
" illegal: requiring max version "
1125 << spirv::stringifyVersion(*maxVersion) <<
"\n");
1133 if (
auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
1135 extensions.getExtensions())))
1141 if (
auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
1143 capabilities.getCapabilities())))
1151 if (llvm::any_of(valueTypes,
1152 [](
Type t) {
return !isa<spirv::SPIRVType>(t); }))
1157 if (
auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
1158 valueTypes.push_back(globalVar.getType());
1164 for (
Type valueType : valueTypes) {
1165 typeExtensions.clear();
1166 cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
1171 typeCapabilities.clear();
1172 cast<spirv::SPIRVType>(valueType).getCapabilities(typeCapabilities);
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static bool needsExplicitLayout(spirv::StorageClass storageClass)
Returns true if the given storageClass needs explicit layout when used in Shader environments.
static spirv::GlobalVariableOp getPushConstantVariable(Block &body, unsigned elementCount)
Returns the push constant variable containing elementCount 32-bit integer values in body.
static Type convertSubByteMemrefType(const spirv::TargetEnv &targetEnv, const SPIRVConversionOptions &options, MemRefType type, spirv::StorageClass storageClass)
static Type convertTensorType(const spirv::TargetEnv &targetEnv, const SPIRVConversionOptions &options, TensorType type)
Converts a tensor type to a suitable type under the given targetEnv.
static LogicalResult checkCapabilityRequirements(LabelT label, const spirv::TargetEnv &targetEnv, const spirv::SPIRVType::CapabilityArrayRefVector &candidates)
Checks that the capability requirements in candidates can be satisfied by the given targetEnv.
static std::optional< int64_t > getTypeNumBytes(const SPIRVConversionOptions &options, Type type)
static Type convertSubByteIntegerType(const SPIRVConversionOptions &options, IntegerType type)
Converts a sub-byte integer type to i32 regardless of target environment.
static spirv::GlobalVariableOp getBuiltinVariable(Block &body, spirv::BuiltIn builtin)
static ShapedType convertIndexElementType(ShapedType type, const SPIRVConversionOptions &options)
Returns a type with the same shape but with any index element type converted to the matching integer ...
static spirv::GlobalVariableOp getOrInsertPushConstantVariable(Location loc, Block &block, unsigned elementCount, OpBuilder &b, Type indexType)
Gets or inserts a global variable for push constant storage containing elementCount 32-bit integer va...
static Type convertComplexType(const spirv::TargetEnv &targetEnv, const SPIRVConversionOptions &options, ComplexType type, std::optional< spirv::StorageClass > storageClass={})
static LogicalResult checkExtensionRequirements(LabelT label, const spirv::TargetEnv &targetEnv, const spirv::SPIRVType::ExtensionArrayRefVector &candidates)
Checks that the extension requirements in candidates can be satisfied by the given targetEnv.
std::optional< Value > castToSourceType(const spirv::TargetEnv &targetEnv, OpBuilder &builder, Type type, ValueRange inputs, Location loc)
Converts the given inputs to the original source type considering the targetEnv's capabilities.
static spirv::ScalarType getIndexType(MLIRContext *ctx, const SPIRVConversionOptions &options)
static spirv::GlobalVariableOp getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin, Type integerType, OpBuilder &builder, StringRef prefix, StringRef suffix)
Gets or inserts a global variable for a builtin within body block.
static std::string getBuiltinVarName(spirv::BuiltIn builtin, StringRef prefix, StringRef suffix)
Gets name of global variable for a builtin.
static Type convertScalarType(const spirv::TargetEnv &targetEnv, const SPIRVConversionOptions &options, spirv::ScalarType type, std::optional< spirv::StorageClass > storageClass={})
Converts a scalar type to a suitable type under the given targetEnv.
static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv, const SPIRVConversionOptions &options, MemRefType type, spirv::StorageClass storageClass)
static Type convertVectorType(const spirv::TargetEnv &targetEnv, const SPIRVConversionOptions &options, VectorType type, std::optional< spirv::StorageClass > storageClass={})
Converts a vector type to a suitable type under the given targetEnv.
static spirv::PointerType wrapInStructAndGetPointer(Type elementType, spirv::StorageClass storageClass)
Wraps the given elementType in a struct and gets the pointer to the struct.
static spirv::PointerType getPushConstantStorageType(unsigned elementCount, Builder &builder, Type indexType)
Returns the pointer type for the push constant storage containing elementCount 32-bit integer values.
static Type convertMemrefType(const spirv::TargetEnv &targetEnv, const SPIRVConversionOptions &options, MemRefType type)
Block represents an ordered list of Operations.
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI32IntegerAttr(int32_t value)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
This class implements a pattern rewriter for use with ConversionPatterns.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Convert the types of block arguments within the given region.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before) override
PatternRewriter hook for moving blocks out of a region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
static OpBuilder atBlockBegin(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still ins...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
operand_type_iterator operand_type_end()
Location getLoc()
The source location the operation was defined or derived from.
result_type_iterator result_type_end()
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
result_type_iterator result_type_begin()
void setAttr(StringAttr name, Attribute value)
If an attribute with the specified name exists, change it to the new value.
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_iterator operand_type_begin()
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
Type getIndexType() const
Gets the SPIR-V correspondence for the standard index type.
SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr, const SPIRVConversionOptions &options={})
bool allows(spirv::Capability capability) const
Checks if the SPIR-V capability inquired is supported.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Type getElementType() const
Returns the element type of this tensor type.
This class provides all of the information necessary to convert a type signature.
void addConversion(FnT &&callback)
Register a conversion function.
void addSourceMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a legal type to an illega...
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting type from an illegal,...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger(unsigned width) const
Return true if this is an integer type with the specified width.
bool isSignedInteger() const
Return true if this is a signed integer type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type; asserts on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static ArrayType get(Type elementType, unsigned elementCount)
static bool isValid(VectorType)
Returns true if the given vector type is valid for the SPIR-V dialect.
static PointerType get(Type pointeeType, StorageClass storageClass)
static RuntimeArrayType get(Type elementType)
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
An attribute that specifies the target version, allowed extensions and capabilities,...
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
Version getVersion() const
bool allows(Capability) const
Returns true if the given capability is allowed.
TargetEnvAttr getAttr() const
MLIRContext * getContext() const
Returns the MLIRContext.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
Value getBuiltinVariableValue(Operation *op, BuiltIn builtin, Type integerType, OpBuilder &builder, StringRef prefix="__builtin__", StringRef suffix="__")
Returns the value for the given builtin variable.
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Value getOpenCLElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Value getPushConstantValue(Operation *op, unsigned elementCount, unsigned offset, Type integerType, OpBuilder &builder)
Gets the value at the given offset of the push constant storage with a total of elementCount integerT...
Value linearizeIndex(ValueRange indices, ArrayRef< int64_t > strides, int64_t offset, Type integerType, Location loc, OpBuilder &builder)
Generates IR to perform index linearization with the given indices and their corresponding strides,...
Value getVulkanElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
@ Packed
Sub-byte values are tightly packed without any padding, e.g., 4xi2 -> i8.
void populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating the builtin func op to the SPIR-V diale...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.