31 #include "llvm/ADT/STLExtras.h" 
   32 #include "llvm/ADT/SmallVector.h" 
   33 #include "llvm/ADT/StringExtras.h" 
   34 #include "llvm/Support/Debug.h" 
   35 #include "llvm/Support/MathExtras.h" 
   39 #define DEBUG_TYPE "mlir-spirv-conversion" 
   49 static std::optional<SmallVector<int64_t>> 
getTargetShape(VectorType vecType) {
 
   50   LLVM_DEBUG(llvm::dbgs() << 
"Get target shape\n");
 
   51   if (vecType.isScalable()) {
 
   52     LLVM_DEBUG(llvm::dbgs()
 
   53                << 
"--scalable vectors are not supported -> BAIL\n");
 
   60     LLVM_DEBUG(llvm::dbgs() << 
"--no unrolling target shape defined\n");
 
   64   if (!maybeShapeRatio) {
 
   65     LLVM_DEBUG(llvm::dbgs()
 
   66                << 
"--could not compute integral shape ratio -> BAIL\n");
 
   69   if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { 
return v == 1; })) {
 
   70     LLVM_DEBUG(llvm::dbgs() << 
"--no unrolling needed -> SKIP\n");
 
   73   LLVM_DEBUG(llvm::dbgs()
 
   74              << 
"--found an integral shape ratio to unroll to -> SUCCESS\n");
 
   84 template <
typename LabelT>
 
   85 static LogicalResult checkExtensionRequirements(
 
   88   for (
const auto &ors : candidates) {
 
   94       for (spirv::Extension ext : ors)
 
   95         extStrings.push_back(spirv::stringifyExtension(ext));
 
   97       llvm::dbgs() << label << 
" illegal: requires at least one extension in [" 
   98                    << llvm::join(extStrings, 
", ")
 
   99                    << 
"] but none allowed in target environment\n";
 
  112 template <
typename LabelT>
 
  113 static LogicalResult checkCapabilityRequirements(
 
  116   for (
const auto &ors : candidates) {
 
  117     if (targetEnv.
allows(ors))
 
  122       for (spirv::Capability cap : ors)
 
  123         capStrings.push_back(spirv::stringifyCapability(cap));
 
  125       llvm::dbgs() << label << 
" illegal: requires at least one capability in [" 
  126                    << llvm::join(capStrings, 
", ")
 
  127                    << 
"] but none allowed in target environment\n";
 
  136 static bool needsExplicitLayout(spirv::StorageClass storageClass) {
 
  137   switch (storageClass) {
 
  138   case spirv::StorageClass::PhysicalStorageBuffer:
 
  139   case spirv::StorageClass::PushConstant:
 
  140   case spirv::StorageClass::StorageBuffer:
 
  141   case spirv::StorageClass::Uniform:
 
  151 wrapInStructAndGetPointer(
Type elementType, spirv::StorageClass storageClass) {
 
  152   auto structType = needsExplicitLayout(storageClass)
 
  164   return cast<spirv::ScalarType>(
 
  170 static std::optional<int64_t>
 
  172   if (isa<spirv::ScalarType>(type)) {
 
  186   if (
options.emulateUnsupportedFloatTypes && isa<FloatType>(type)) {
 
  193   if (
auto complexType = dyn_cast<ComplexType>(type)) {
 
  194     auto elementSize = getTypeNumBytes(
options, complexType.getElementType());
 
  197     return 2 * *elementSize;
 
  200   if (
auto vecType = dyn_cast<VectorType>(type)) {
 
  201     auto elementSize = getTypeNumBytes(
options, vecType.getElementType());
 
  204     return vecType.getNumElements() * *elementSize;
 
  207   if (
auto memRefType = dyn_cast<MemRefType>(type)) {
 
  212     if (!memRefType.hasStaticShape() ||
 
  213         failed(memRefType.getStridesAndOffset(strides, offset)))
 
  219     auto elementSize = getTypeNumBytes(
options, memRefType.getElementType());
 
  223     if (memRefType.getRank() == 0)
 
  226     auto dims = memRefType.getShape();
 
  227     if (llvm::is_contained(dims, ShapedType::kDynamic) ||
 
  228         ShapedType::isDynamic(offset) ||
 
  229         llvm::is_contained(strides, ShapedType::kDynamic))
 
  232     int64_t memrefSize = -1;
 
  233     for (
const auto &shape : 
enumerate(dims))
 
  234       memrefSize = 
std::max(memrefSize, shape.value() * strides[shape.index()]);
 
  236     return (offset + memrefSize) * *elementSize;
 
  239   if (
auto tensorType = dyn_cast<TensorType>(type)) {
 
  240     if (!tensorType.hasStaticShape())
 
  243     auto elementSize = getTypeNumBytes(
options, tensorType.getElementType());
 
  247     int64_t size = *elementSize;
 
  248     for (
auto shape : tensorType.getShape())
 
  262                   std::optional<spirv::StorageClass> storageClass = {}) {
 
  266   type.getExtensions(extensions, storageClass);
 
  267   type.getCapabilities(capabilities, storageClass);
 
  270   if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
 
  271       succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
 
  276   if (!
options.emulateLT32BitScalarTypes)
 
  281     LLVM_DEBUG(llvm::dbgs()
 
  283                << 
" not converted to 32-bit for SPIR-V to avoid truncation\n");
 
  287   if (
auto floatType = dyn_cast<FloatType>(type)) {
 
  288     LLVM_DEBUG(llvm::dbgs() << type << 
" converted to 32-bit for SPIR-V\n");
 
  292   auto intType = cast<IntegerType>(type);
 
  293   LLVM_DEBUG(llvm::dbgs() << type << 
" converted to 32-bit for SPIR-V\n");
 
  295                           intType.getSignedness());
 
  308   if (type.getWidth() > 8) {
 
  309     LLVM_DEBUG(llvm::dbgs() << 
"not a subbyte type\n");
 
  313     LLVM_DEBUG(llvm::dbgs() << 
"unsupported sub-byte storage kind\n");
 
  317   if (!llvm::isPowerOf2_32(type.getWidth())) {
 
  318     LLVM_DEBUG(llvm::dbgs()
 
  319                << 
"unsupported non-power-of-two bitwidth in sub-byte" << type
 
  324   LLVM_DEBUG(llvm::dbgs() << type << 
" converted to 32-bit for SPIR-V\n");
 
  326                           type.getSignedness());
 
  333   if (!
options.emulateUnsupportedFloatTypes)
 
  336   if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
 
  337           Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
 
  338           Float8E8M0FNUType>(type))
 
  340   LLVM_DEBUG(llvm::dbgs() << 
"unsupported 8-bit float type: " << type << 
"\n");
 
  348 convertShaped8BitFloatType(ShapedType type,
 
  350   if (!
options.emulateUnsupportedFloatTypes)
 
  352   Type srcElementType = type.getElementType();
 
  353   Type convertedElementType = 
nullptr;
 
  355   if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
 
  356           Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
 
  357           Float8E8M0FNUType>(srcElementType))
 
  361   if (!convertedElementType)
 
  364   return type.clone(convertedElementType);
 
  371 convertIndexElementType(ShapedType type,
 
  373   Type indexType = dyn_cast<IndexType>(type.getElementType());
 
  384                   std::optional<spirv::StorageClass> storageClass = {}) {
 
  385   type = cast<VectorType>(convertIndexElementType(type, 
options));
 
  386   type = cast<VectorType>(convertShaped8BitFloatType(type, 
options));
 
  387   auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
 
  391     auto intType = dyn_cast<IntegerType>(type.getElementType());
 
  393       LLVM_DEBUG(llvm::dbgs()
 
  395                  << 
" illegal: cannot convert non-scalar element type\n");
 
  399     Type elementType = convertSubByteIntegerType(
options, intType);
 
  403     if (type.getRank() <= 1 && type.getNumElements() == 1)
 
  406     if (type.getNumElements() > 4) {
 
  407       LLVM_DEBUG(llvm::dbgs()
 
  408                  << type << 
" illegal: > 4-element unimplemented\n");
 
  415   if (type.getRank() <= 1 && type.getNumElements() == 1)
 
  416     return convertScalarType(targetEnv, 
options, scalarType, storageClass);
 
  419     LLVM_DEBUG(llvm::dbgs()
 
  420                << type << 
" illegal: not a valid composite type\n");
 
  427   cast<spirv::CompositeType>(type).getExtensions(extensions, storageClass);
 
  428   cast<spirv::CompositeType>(type).getCapabilities(capabilities, storageClass);
 
  431   if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
 
  432       succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
 
  436       convertScalarType(targetEnv, 
options, scalarType, storageClass);
 
  445                    std::optional<spirv::StorageClass> storageClass = {}) {
 
  446   auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
 
  448     LLVM_DEBUG(llvm::dbgs()
 
  449                << type << 
" illegal: cannot convert non-scalar element type\n");
 
  454       convertScalarType(targetEnv, 
options, scalarType, storageClass);
 
  457   if (elementType != type.getElementType()) {
 
  458     LLVM_DEBUG(llvm::dbgs()
 
  459                << type << 
" illegal: complex type emulation unsupported\n");
 
  476   if (!type.hasStaticShape()) {
 
  477     LLVM_DEBUG(llvm::dbgs()
 
  478                << type << 
" illegal: dynamic shape unimplemented\n");
 
  482   type = cast<TensorType>(convertIndexElementType(type, 
options));
 
  483   type = cast<TensorType>(convertShaped8BitFloatType(type, 
options));
 
  484   auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.
getElementType());
 
  486     LLVM_DEBUG(llvm::dbgs()
 
  487                << type << 
" illegal: cannot convert non-scalar element type\n");
 
  491   std::optional<int64_t> scalarSize = getTypeNumBytes(
options, scalarType);
 
  492   std::optional<int64_t> tensorSize = getTypeNumBytes(
options, type);
 
  493   if (!scalarSize || !tensorSize) {
 
  494     LLVM_DEBUG(llvm::dbgs()
 
  495                << type << 
" illegal: cannot deduce element count\n");
 
  499   int64_t arrayElemCount = *tensorSize / *scalarSize;
 
  500   if (arrayElemCount == 0) {
 
  501     LLVM_DEBUG(llvm::dbgs()
 
  502                << type << 
" illegal: cannot handle zero-element tensors\n");
 
  506   Type arrayElemType = convertScalarType(targetEnv, 
options, scalarType);
 
  509   std::optional<int64_t> arrayElemSize =
 
  510       getTypeNumBytes(
options, arrayElemType);
 
  511   if (!arrayElemSize) {
 
  512     LLVM_DEBUG(llvm::dbgs()
 
  513                << type << 
" illegal: cannot deduce converted element size\n");
 
  523                                   spirv::StorageClass storageClass) {
 
  524   unsigned numBoolBits = 
options.boolNumBits;
 
  525   if (numBoolBits != 8) {
 
  526     LLVM_DEBUG(llvm::dbgs()
 
  527                << 
"using non-8-bit storage for bool types unimplemented");
 
  530   auto elementType = dyn_cast<spirv::ScalarType>(
 
  535       convertScalarType(targetEnv, 
options, elementType, storageClass);
 
  538   std::optional<int64_t> arrayElemSize =
 
  539       getTypeNumBytes(
options, arrayElemType);
 
  540   if (!arrayElemSize) {
 
  541     LLVM_DEBUG(llvm::dbgs()
 
  542                << type << 
" illegal: cannot deduce converted element size\n");
 
  546   if (!type.hasStaticShape()) {
 
  549     if (targetEnv.
allows(spirv::Capability::Kernel))
 
  551     int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
 
  555     return wrapInStructAndGetPointer(arrayType, storageClass);
 
  558   if (type.getNumElements() == 0) {
 
  559     LLVM_DEBUG(llvm::dbgs()
 
  560                << type << 
" illegal: zero-element memrefs are not supported\n");
 
  564   int64_t memrefSize = 
llvm::divideCeil(type.getNumElements() * numBoolBits, 8);
 
  566   int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
 
  568   if (targetEnv.
allows(spirv::Capability::Kernel))
 
  570   return wrapInStructAndGetPointer(arrayType, storageClass);
 
  576                                      spirv::StorageClass storageClass) {
 
  577   IntegerType elementType = cast<IntegerType>(type.getElementType());
 
  578   Type arrayElemType = convertSubByteIntegerType(
options, elementType);
 
  581   int64_t arrayElemSize = *getTypeNumBytes(
options, arrayElemType);
 
  583   if (!type.hasStaticShape()) {
 
  586     if (targetEnv.
allows(spirv::Capability::Kernel))
 
  588     int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0;
 
  592     return wrapInStructAndGetPointer(arrayType, storageClass);
 
  595   if (type.getNumElements() == 0) {
 
  596     LLVM_DEBUG(llvm::dbgs()
 
  597                << type << 
" illegal: zero-element memrefs are not supported\n");
 
  604   int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0;
 
  606   if (targetEnv.
allows(spirv::Capability::Kernel))
 
  608   return wrapInStructAndGetPointer(arrayType, storageClass);
 
  611 static spirv::Dim convertRank(int64_t rank) {
 
  614     return spirv::Dim::Dim1D;
 
  616     return spirv::Dim::Dim2D;
 
  618     return spirv::Dim::Dim3D;
 
  620     llvm_unreachable(
"Invalid memref rank!");
 
  624 static spirv::ImageFormat getImageFormat(
Type elementType) {
 
  626       .Case<Float16Type>([](Float16Type) { 
return spirv::ImageFormat::R16f; })
 
  627       .Case<Float32Type>([](Float32Type) { 
return spirv::ImageFormat::R32f; })
 
  628       .Case<IntegerType>([](IntegerType intType) {
 
  629         auto const isSigned = intType.isSigned() || intType.isSignless();
 
  630 #define BIT_WIDTH_CASE(BIT_WIDTH)                                              \ 
  632     return isSigned ? spirv::ImageFormat::R##BIT_WIDTH##i                      \ 
  633                     : spirv::ImageFormat::R##BIT_WIDTH##ui 
  635         switch (intType.getWidth()) {
 
  639           llvm_unreachable(
"Unhandled integer type!");
 
  642       .DefaultUnreachable(
"Unhandled element type!");
 
  643 #undef BIT_WIDTH_CASE 
  649   auto attr = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
 
  654         << 
" illegal: expected memory space to be a SPIR-V storage class " 
  655            "attribute; please use MemorySpaceToStorageClassConverter to map " 
  656            "numeric memory spaces beforehand\n");
 
  659   spirv::StorageClass storageClass = attr.getValue();
 
  664   if (storageClass == spirv::StorageClass::Image) {
 
  665     const int64_t rank = type.getRank();
 
  666     if (rank < 1 || rank > 3) {
 
  667       LLVM_DEBUG(llvm::dbgs()
 
  668                  << type << 
" illegal: cannot lower memref of rank " << rank
 
  669                  << 
" to a SPIR-V Image\n");
 
  675     auto elementType = type.getElementType();
 
  676     if (!isa<spirv::ScalarType>(elementType)) {
 
  677       LLVM_DEBUG(llvm::dbgs() << type << 
" illegal: cannot lower memref of " 
  678                               << elementType << 
" to a  SPIR-V Image\n");
 
  686         elementType, convertRank(rank), spirv::ImageDepthInfo::DepthUnknown,
 
  687         spirv::ImageArrayedInfo::NonArrayed,
 
  688         spirv::ImageSamplingInfo::SingleSampled,
 
  689         spirv::ImageSamplerUseInfo::NeedSampler, getImageFormat(elementType));
 
  692         spvSampledImageType, spirv::StorageClass::UniformConstant);
 
  696   if (isa<IntegerType>(type.getElementType())) {
 
  697     if (type.getElementTypeBitWidth() == 1)
 
  698       return convertBoolMemrefType(targetEnv, 
options, type, storageClass);
 
  699     if (type.getElementTypeBitWidth() < 8)
 
  700       return convertSubByteMemrefType(targetEnv, 
options, type, storageClass);
 
  704   Type elementType = type.getElementType();
 
  705   if (
auto vecType = dyn_cast<VectorType>(elementType)) {
 
  707         convertVectorType(targetEnv, 
options, vecType, storageClass);
 
  708   } 
else if (
auto complexType = dyn_cast<ComplexType>(elementType)) {
 
  710         convertComplexType(targetEnv, 
options, complexType, storageClass);
 
  711   } 
else if (
auto scalarType = dyn_cast<spirv::ScalarType>(elementType)) {
 
  713         convertScalarType(targetEnv, 
options, scalarType, storageClass);
 
  714   } 
else if (
auto indexType = dyn_cast<IndexType>(elementType)) {
 
  715     type = cast<MemRefType>(convertIndexElementType(type, 
options));
 
  716     arrayElemType = type.getElementType();
 
  717   } 
else if (
auto floatType = dyn_cast<FloatType>(elementType)) {
 
  719     type = cast<MemRefType>(convertShaped8BitFloatType(type, 
options));
 
  720     arrayElemType = type.getElementType();
 
  725         << 
" unhandled: can only convert scalar or vector element type\n");
 
  731   std::optional<int64_t> arrayElemSize =
 
  732       getTypeNumBytes(
options, arrayElemType);
 
  733   if (!arrayElemSize) {
 
  734     LLVM_DEBUG(llvm::dbgs()
 
  735                << type << 
" illegal: cannot deduce converted element size\n");
 
  739   if (!type.hasStaticShape()) {
 
  742     if (targetEnv.
allows(spirv::Capability::Kernel))
 
  744     int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
 
  748     return wrapInStructAndGetPointer(arrayType, storageClass);
 
  751   std::optional<int64_t> memrefSize = getTypeNumBytes(
options, type);
 
  753     LLVM_DEBUG(llvm::dbgs()
 
  754                << type << 
" illegal: cannot deduce element count\n");
 
  758   if (*memrefSize == 0) {
 
  759     LLVM_DEBUG(llvm::dbgs()
 
  760                << type << 
" illegal: zero-element memrefs are not supported\n");
 
  765   int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
 
  767   if (targetEnv.
allows(spirv::Capability::Kernel))
 
  769   return wrapInStructAndGetPointer(arrayType, storageClass);
 
  793   if (inputs.size() != 1) {
 
  795         UnrealizedConversionCastOp::create(builder, loc, type, inputs);
 
  796     return castOp.getResult(0);
 
  798   Value input = inputs.front();
 
  801   if (!isa<IntegerType>(type)) {
 
  803         UnrealizedConversionCastOp::create(builder, loc, type, inputs);
 
  804     return castOp.getResult(0);
 
  806   auto inputType = cast<IntegerType>(input.
getType());
 
  808   auto scalarType = dyn_cast<spirv::ScalarType>(type);
 
  811         UnrealizedConversionCastOp::create(builder, loc, type, inputs);
 
  812     return castOp.getResult(0);
 
  818   if (inputType.getIntOrFloatBitWidth() < scalarType.getIntOrFloatBitWidth()) {
 
  820         UnrealizedConversionCastOp::create(builder, loc, type, inputs);
 
  821     return castOp.getResult(0);
 
  826     Value one = spirv::ConstantOp::getOne(inputType, loc, builder);
 
  827     return spirv::IEqualOp::create(builder, loc, input, one);
 
  833   scalarType.getExtensions(exts);
 
  834   scalarType.getCapabilities(caps);
 
  835   if (
failed(checkCapabilityRequirements(type, targetEnv, caps)) ||
 
  836       failed(checkExtensionRequirements(type, targetEnv, exts))) {
 
  838         UnrealizedConversionCastOp::create(builder, loc, type, inputs);
 
  839     return castOp.getResult(0);
 
  846     return spirv::SConvertOp::create(builder, loc, type, input);
 
  848   return spirv::UConvertOp::create(builder, loc, type, input);
 
  855 static spirv::GlobalVariableOp getBuiltinVariable(
Block &body,
 
  856                                                   spirv::BuiltIn builtin) {
 
  859   for (
auto varOp : body.
getOps<spirv::GlobalVariableOp>()) {
 
  860     if (
auto builtinAttr = varOp->getAttrOfType<StringAttr>(
 
  861             spirv::SPIRVDialect::getAttributeName(
 
  862                 spirv::Decoration::BuiltIn))) {
 
  863       auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
 
  864       if (varBuiltIn == builtin) {
 
  873 std::string getBuiltinVarName(spirv::BuiltIn builtin, StringRef prefix,
 
  875   return Twine(prefix).concat(stringifyBuiltIn(builtin)).concat(suffix).str();
 
  879 static spirv::GlobalVariableOp
 
  880 getOrInsertBuiltinVariable(
Block &body, 
Location loc, spirv::BuiltIn builtin,
 
  882                            StringRef prefix, StringRef suffix) {
 
  883   if (
auto varOp = getBuiltinVariable(body, builtin))
 
  889   spirv::GlobalVariableOp newVarOp;
 
  891   case spirv::BuiltIn::NumWorkgroups:
 
  892   case spirv::BuiltIn::WorkgroupSize:
 
  893   case spirv::BuiltIn::WorkgroupId:
 
  894   case spirv::BuiltIn::LocalInvocationId:
 
  895   case spirv::BuiltIn::GlobalInvocationId: {
 
  897                                            spirv::StorageClass::Input);
 
  898     std::string name = getBuiltinVarName(builtin, prefix, suffix);
 
  900         spirv::GlobalVariableOp::create(builder, loc, ptrType, name, builtin);
 
  903   case spirv::BuiltIn::SubgroupId:
 
  904   case spirv::BuiltIn::NumSubgroups:
 
  905   case spirv::BuiltIn::SubgroupSize:
 
  906   case spirv::BuiltIn::SubgroupLocalInvocationId: {
 
  909     std::string name = getBuiltinVarName(builtin, prefix, suffix);
 
  911         spirv::GlobalVariableOp::create(builder, loc, ptrType, name, builtin);
 
  915     emitError(loc, 
"unimplemented builtin variable generation for ")
 
  916         << stringifyBuiltIn(builtin);
 
  938 static spirv::GlobalVariableOp getPushConstantVariable(
Block &body,
 
  939                                                        unsigned elementCount) {
 
  940   for (
auto varOp : body.
getOps<spirv::GlobalVariableOp>()) {
 
  941     auto ptrType = dyn_cast<spirv::PointerType>(varOp.getType());
 
  948     if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) {
 
  949       auto numElements = cast<spirv::ArrayType>(
 
  950                              cast<spirv::StructType>(ptrType.getPointeeType())
 
  953       if (numElements == elementCount)
 
  962 static spirv::GlobalVariableOp
 
  966   if (
auto varOp = getPushConstantVariable(block, elementCount))
 
  970   auto type = getPushConstantStorageType(elementCount, builder, indexType);
 
  971   const char *name = 
"__push_constant_var__";
 
  972   return spirv::GlobalVariableOp::create(builder, loc, type, name,
 
  986   matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
 
  988     FunctionType fnType = funcOp.getFunctionType();
 
  989     if (fnType.getNumResults() > 1)
 
  993         fnType.getNumInputs());
 
  994     for (
const auto &argType : 
enumerate(fnType.getInputs())) {
 
  995       auto convertedType = getTypeConverter()->convertType(argType.value());
 
  998       signatureConverter.
addInputs(argType.index(), convertedType);
 
 1002     if (fnType.getNumResults() == 1) {
 
 1003       resultType = getTypeConverter()->convertType(fnType.getResult(0));
 
 1009     auto newFuncOp = spirv::FuncOp::create(
 
 1010         rewriter, funcOp.getLoc(), funcOp.getName(),
 
 1016     for (
const auto &namedAttr : funcOp->getAttrs()) {
 
 1017       if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() &&
 
 1019         newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
 
 1025             &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter)))
 
 1037   LogicalResult matchAndRewrite(func::FuncOp funcOp,
 
 1039     FunctionType fnType = funcOp.getFunctionType();
 
 1042     if (funcOp.isDeclaration()) {
 
 1043       LLVM_DEBUG(llvm::dbgs()
 
 1044                  << fnType << 
" illegal: declarations are unsupported\n");
 
 1049     auto newFuncOp = func::FuncOp::create(rewriter, funcOp.getLoc(),
 
 1050                                           funcOp.getName(), fnType);
 
 1054     Location loc = newFuncOp.getBody().getLoc();
 
 1056     Block &entryBlock = newFuncOp.getBlocks().
front();
 
 1061         fnType.getInputs().size());
 
 1067     size_t newInputNo = 0;
 
 1073     llvm::SmallDenseMap<Operation *, size_t> tmpOps;
 
 1076     size_t newOpCount = 0;
 
 1079     for (
auto [origInputNo, origType] : 
enumerate(fnType.getInputs())) {
 
 1081       auto origVecType = dyn_cast<VectorType>(origType);
 
 1084         Value result = arith::ConstantOp::create(
 
 1085             rewriter, loc, origType, rewriter.
getZeroAttr(origType));
 
 1088         oneToNTypeMapping.
addInputs(origInputNo, origType);
 
 1097         Value result = arith::ConstantOp::create(
 
 1098             rewriter, loc, origType, rewriter.
getZeroAttr(origType));
 
 1101         oneToNTypeMapping.
addInputs(origInputNo, origType);
 
 1106       VectorType unrolledType =
 
 1108       auto originalShape =
 
 1109           llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
 
 1112       Value result = arith::ConstantOp::create(
 
 1113           rewriter, loc, origVecType, rewriter.
getZeroAttr(origVecType));
 
 1116       Value dummy = arith::ConstantOp::create(
 
 1117           rewriter, loc, unrolledType, rewriter.
getZeroAttr(unrolledType));
 
 1125         result = vector::InsertStridedSliceOp::create(rewriter, loc, dummy,
 
 1126                                                       result, offsets, strides);
 
 1127         newTypes.push_back(unrolledType);
 
 1128         unrolledInputNums.push_back(newInputNo);
 
 1133       oneToNTypeMapping.
addInputs(origInputNo, newTypes);
 
 1138     auto newFnType = fnType.clone(convertedTypes, fnType.getResults());
 
 1140                              [&] { newFuncOp.setFunctionType(newFnType); });
 
 1149     for (
auto &[placeholderOp, argIdx] : tmpOps) {
 
 1152       Value replacement = newFuncOp.getArgument(argIdx);
 
 1160     size_t unrolledInputIdx = 0;
 
 1166       if (count >= newOpCount)
 
 1168       if (
auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op)) {
 
 1169         size_t unrolledInputNo = unrolledInputNums[unrolledInputIdx];
 
 1171           curOp.
setOperand(0, newFuncOp.getArgument(unrolledInputNo));
 
 1193   LogicalResult matchAndRewrite(func::ReturnOp returnOp,
 
 1196     auto funcOp = dyn_cast<func::FuncOp>(returnOp->getParentOp());
 
 1200     FunctionType fnType = funcOp.getFunctionType();
 
 1202         fnType.getResults().size());
 
 1209     for (
auto [origResultNo, origType] : 
enumerate(fnType.getResults())) {
 
 1211       auto origVecType = dyn_cast<VectorType>(origType);
 
 1213         oneToNTypeMapping.
addInputs(origResultNo, origType);
 
 1214         newOperands.push_back(returnOp.getOperand(origResultNo));
 
 1221         oneToNTypeMapping.
addInputs(origResultNo, origType);
 
 1222         newOperands.push_back(returnOp.getOperand(origResultNo));
 
 1225       VectorType unrolledType =
 
 1230       auto originalShape =
 
 1231           llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
 
 1234       extractShape.back() = targetShape->back();
 
 1236       Value returnValue = returnOp.getOperand(origResultNo);
 
 1239         Value result = vector::ExtractStridedSliceOp::create(
 
 1240             rewriter, loc, returnValue, offsets, extractShape, strides);
 
 1241         if (originalShape.size() > 1) {
 
 1244               vector::ExtractOp::create(rewriter, loc, result, extractIndices);
 
 1246         newOperands.push_back(result);
 
 1247         newTypes.push_back(unrolledType);
 
 1249       oneToNTypeMapping.
addInputs(origResultNo, newTypes);
 
 1257                              [&] { funcOp.setFunctionType(newFnType); });
 
 1262                        func::ReturnOp::create(rewriter, loc, newOperands));
 
 1275                                            spirv::BuiltIn builtin,
 
 1277                                            StringRef prefix, StringRef suffix) {
 
 1280     op->
emitError(
"expected operation to be within a module-like op");
 
 1284   spirv::GlobalVariableOp varOp =
 
 1286                                  builtin, integerType, builder, prefix, suffix);
 
 1287   Value ptr = spirv::AddressOfOp::create(builder, op->
getLoc(), varOp);
 
 1288   return spirv::LoadOp::create(builder, op->
getLoc(), ptr);
 
 1296                                   unsigned offset, 
Type integerType,
 
 1301     op->
emitError(
"expected operation to be within a module-like op");
 
 1305   spirv::GlobalVariableOp varOp = getOrInsertPushConstantVariable(
 
 1306       loc, parent->
getRegion(0).
front(), elementCount, builder, integerType);
 
 1309   Value offsetOp = spirv::ConstantOp::create(builder, loc, integerType,
 
 1311   auto addrOp = spirv::AddressOfOp::create(builder, loc, varOp);
 
 1312   auto acOp = spirv::AccessChainOp::create(builder, loc, addrOp,
 
 1314   return spirv::LoadOp::create(builder, loc, acOp);
 
 1322                                   int64_t offset, 
Type integerType,
 
 1324   assert(indices.size() == strides.size() &&
 
 1325          "must provide indices for all dimensions");
 
 1339         builder.
createOrFold<spirv::IMulOp>(loc, index.value(), strideVal);
 
 1341         builder.
createOrFold<spirv::IAddOp>(loc, update, linearizedIndex);
 
 1343   return linearizedIndex;
 
 1347                                        MemRefType baseType, 
Value basePtr,
 
 1354   if (
failed(baseType.getStridesAndOffset(strides, offset)) ||
 
 1355       llvm::is_contained(strides, ShapedType::kDynamic) ||
 
 1356       ShapedType::isDynamic(offset)) {
 
 1366   linearizedIndices.push_back(zero);
 
 1368   if (baseType.getRank() == 0) {
 
 1369     linearizedIndices.push_back(zero);
 
 1371     linearizedIndices.push_back(
 
 1372         linearizeIndex(indices, strides, offset, indexType, loc, builder));
 
 1374   return spirv::AccessChainOp::create(builder, loc, basePtr, linearizedIndices);
 
 1378                                        MemRefType baseType, 
Value basePtr,
 
 1385   if (
failed(baseType.getStridesAndOffset(strides, offset)) ||
 
 1386       llvm::is_contained(strides, ShapedType::kDynamic) ||
 
 1387       ShapedType::isDynamic(offset)) {
 
 1395   if (baseType.getRank() == 0) {
 
 1399         linearizeIndex(indices, strides, offset, indexType, loc, builder);
 
 1402       cast<spirv::PointerType>(basePtr.
getType()).getPointeeType();
 
 1403   if (isa<spirv::ArrayType>(pointeeType)) {
 
 1404     linearizedIndices.push_back(linearIndex);
 
 1405     return spirv::AccessChainOp::create(builder, loc, basePtr,
 
 1408   return spirv::PtrAccessChainOp::create(builder, loc, basePtr, linearIndex,
 
 1413                                  MemRefType baseType, 
Value basePtr,
 
 1417   if (typeConverter.
allows(spirv::Capability::Kernel)) {
 
 1431   for (
int i : {4, 3, 2}) {
 
 1440   VectorType srcVectorType = op.getSourceVectorType();
 
 1441   assert(srcVectorType.getRank() == 1); 
 
 1442   int64_t vectorSize =
 
 1444   return {vectorSize};
 
 1449   VectorType vectorType = op.getResultVectorType();
 
 1456 std::optional<SmallVector<int64_t>>
 
 1459     if (
auto vecType = dyn_cast<VectorType>(op->
getResultTypes()[0])) {
 
 1468       .Case<vector::ReductionOp, vector::TransposeOp>(
 
 1470       .Default(std::nullopt);
 
 1504         patterns, vector::VectorTransposeLowering::EltWise);
 
 1515     vector::populateCastAwayVectorLeadingOneDimPatterns(
patterns);
 
 1516     vector::ReductionOp::getCanonicalizationPatterns(
patterns, context);
 
 1517     vector::TransposeOp::getCanonicalizationPatterns(
patterns, context);
 
 1521     vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
 
 1523     vector::InsertOp::getCanonicalizationPatterns(
patterns, context);
 
 1524     vector::ExtractOp::getCanonicalizationPatterns(
patterns, context);
 
 1528     vector::BroadcastOp::getCanonicalizationPatterns(
patterns, context);
 
 1529     vector::ShapeCastOp::getCanonicalizationPatterns(
patterns, context);
 
 1558   addConversion([
this](IntegerType intType) -> std::optional<Type> {
 
 1559     if (
auto scalarType = dyn_cast<spirv::ScalarType>(intType))
 
 1560       return convertScalarType(this->targetEnv, this->options, scalarType);
 
 1561     if (intType.getWidth() < 8)
 
 1562       return convertSubByteIntegerType(this->options, intType);
 
 1566   addConversion([
this](FloatType floatType) -> std::optional<Type> {
 
 1567     if (
auto scalarType = dyn_cast<spirv::ScalarType>(floatType))
 
 1568       return convertScalarType(this->targetEnv, this->options, scalarType);
 
 1569     if (floatType.getWidth() == 8)
 
 1570       return convert8BitFloatType(this->options, floatType);
 
 1575     return convertComplexType(this->targetEnv, this->options, complexType);
 
 1579     return convertVectorType(this->targetEnv, this->options, vectorType);
 
 1583     return convertTensorType(this->targetEnv, this->options, tensorType);
 
 1587     return convertMemrefType(this->targetEnv, this->options, memRefType);
 
 1593         return castToSourceType(this->targetEnv, builder, type, inputs, loc);
 
 1597     auto cast = UnrealizedConversionCastOp::create(builder, loc, type, inputs);
 
 1598     return cast.getResult(0);
 
 1603   return ::getIndexType(getContext(), options);
 
 1606 MLIRContext *SPIRVTypeConverter::getContext()
 const {
 
 1607   return targetEnv.
getAttr().getContext();
 
 1611   return targetEnv.
allows(capability);
 
 1618 std::unique_ptr<SPIRVConversionTarget>
 
 1620   std::unique_ptr<SPIRVConversionTarget> target(
 
 1624   target->addDynamicallyLegalDialect<spirv::SPIRVDialect>(
 
 1627       [targetPtr](
Operation *op) { 
return targetPtr->isLegalOp(op); });
 
 1634 bool SPIRVConversionTarget::isLegalOp(
Operation *op) {
 
 1638   if (
auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
 
 1639     std::optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
 
 1640     if (minVersion && *minVersion > this->targetEnv.
getVersion()) {
 
 1641       LLVM_DEBUG(llvm::dbgs()
 
 1642                  << op->
getName() << 
" illegal: requiring min version " 
 1647   if (
auto maxVersionIfx = dyn_cast<spirv::QueryMaxVersionInterface>(op)) {
 
 1648     std::optional<spirv::Version> maxVersion = maxVersionIfx.getMaxVersion();
 
 1649     if (maxVersion && *maxVersion < this->targetEnv.
getVersion()) {
 
 1650       LLVM_DEBUG(llvm::dbgs()
 
 1651                  << op->
getName() << 
" illegal: requiring max version " 
 1660   if (
auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
 
 1661     if (
failed(checkExtensionRequirements(op->
getName(), this->targetEnv,
 
 1662                                           extensions.getExtensions())))
 
 1668   if (
auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
 
 1669     if (
failed(checkCapabilityRequirements(op->
getName(), this->targetEnv,
 
 1670                                            capabilities.getCapabilities())))
 
 1678   if (llvm::any_of(valueTypes,
 
 1679                    [](
Type t) { 
return !isa<spirv::SPIRVType>(t); }))
 
 1684   if (
auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
 
 1685     valueTypes.push_back(globalVar.getType());
 
 1691   for (
Type valueType : valueTypes) {
 
 1692     typeExtensions.clear();
 
 1693     cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
 
 1694     if (
failed(checkExtensionRequirements(op->
getName(), this->targetEnv,
 
 1698     typeCapabilities.clear();
 
 1699     cast<spirv::SPIRVType>(valueType).getCapabilities(typeCapabilities);
 
 1700     if (
failed(checkCapabilityRequirements(op->
getName(), this->targetEnv,
 
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
#define BIT_WIDTH_CASE(BIT_WIDTH)
static std::optional< SmallVector< int64_t > > getTargetShape(const vector::UnrollVectorOptions &options, Operation *op)
Return the target shape for unrolling for the given op.
Block represents an ordered list of Operations.
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
OpListType & getOperations()
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI32IntegerAttr(int32_t value)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
TypedAttr getZeroAttr(Type type)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
This class allows control over how the GreedyPatternRewriteDriver works.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
static OpBuilder atBlockBegin(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still ins...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
Operation is the basic unit of execution within MLIR.
void setOperand(unsigned idx, Value value)
operand_type_iterator operand_type_end()
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
result_type_iterator result_type_end()
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
result_type_iterator result_type_begin()
OperationName getName()
The name of an operation is the key identifier for it.
result_type_range getResultTypes()
unsigned getNumResults()
Return the number of results held by this operation.
operand_type_iterator operand_type_begin()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
Type getIndexType() const
Gets the SPIR-V correspondence for the standard index type.
SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr, const SPIRVConversionOptions &options={})
bool allows(spirv::Capability capability) const
Checks if the SPIR-V capability inquired is supported.
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Type getElementType() const
Returns the element type of this tensor type.
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
void addConversion(FnT &&callback)
Register a conversion function.
void addSourceMaterialization(FnT &&callback)
All of the following materializations require function objects that are convertible to the following ...
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a value to a target type ...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ArrayType get(Type elementType, unsigned elementCount)
static bool isValid(VectorType)
Returns true if the given vector type is valid for the SPIR-V dialect.
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
static PointerType get(Type pointeeType, StorageClass storageClass)
static RuntimeArrayType get(Type elementType)
static SampledImageType get(Type imageType)
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Construct a literal StructType with at least one member.
An attribute that specifies the target version, allowed extensions and capabilities,...
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
Version getVersion() const
bool allows(Capability) const
Returns true if the given capability is allowed.
TargetEnvAttr getAttr() const
MLIRContext * getContext() const
Returns the MLIRContext.
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
OpFoldResult linearizeIndex(ArrayRef< OpFoldResult > multiIndex, ArrayRef< OpFoldResult > basis, ImplicitLocOpBuilder &builder)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
Value getBuiltinVariableValue(Operation *op, BuiltIn builtin, Type integerType, OpBuilder &builder, StringRef prefix="__builtin__", StringRef suffix="__")
Returns the value for the given builtin variable.
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Value getOpenCLElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Value getPushConstantValue(Operation *op, unsigned elementCount, unsigned offset, Type integerType, OpBuilder &builder)
Gets the value at the given offset of the push constant storage with a total of elementCount integerT...
std::optional< SmallVector< int64_t > > getNativeVectorShape(Operation *op)
Value linearizeIndex(ValueRange indices, ArrayRef< int64_t > strides, int64_t offset, Type integerType, Location loc, OpBuilder &builder)
Generates IR to perform index linearization with the given indices and their corresponding strides,...
LogicalResult unrollVectorsInFuncBodies(Operation *op)
Value getVulkanElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
SmallVector< int64_t > getNativeVectorShapeImpl(vector::ReductionOp op)
int getComputeVectorSize(int64_t size)
LogicalResult unrollVectorsInSignatures(Operation *op)
llvm::SmallString< 4 > stringifyVersion(TosaSpecificationVersion version)
void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns, VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
void populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns)
void populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns)
@ Packed
Sub-byte values are tightly packed without any padding, e.g., 4xi2 -> i8.
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
void populateBuiltinFuncToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating the builtin func op to the SPIR-V diale...
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
@ ExistingOps
Only pre-existing ops are processed.
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Options that control the vector unrolling.
UnrollVectorOptions & setNativeShapeFn(NativeShapeFnType fn)