28#include "llvm/ADT/ArrayRef.h" 
   29#include "llvm/ADT/STLExtras.h" 
   30#include "llvm/ADT/SmallVector.h" 
   31#include "llvm/ADT/SmallVectorExtras.h" 
   32#include "llvm/Support/FormatVariadic.h" 
   43  return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
 
 
   51  if (
auto vectorType = dyn_cast<VectorType>(type))
 
   52    return vectorType.getNumElements() * vectorType.getElementTypeBitWidth();
 
 
   58struct VectorShapeCast final : 
public OpConversionPattern<vector::ShapeCastOp> {
 
   62  matchAndRewrite(vector::ShapeCastOp shapeCastOp, OpAdaptor adaptor,
 
   63                  ConversionPatternRewriter &rewriter)
 const override {
 
   64    Type dstType = getTypeConverter()->convertType(shapeCastOp.getType());
 
   70    if (dstType == adaptor.getSource().getType() ||
 
   71        shapeCastOp.getResultVectorType().getNumElements() == 1) {
 
   72      rewriter.replaceOp(shapeCastOp, adaptor.getSource());
 
   81struct VectorBitcastConvert final
 
   82    : 
public OpConversionPattern<vector::BitCastOp> {
 
   86  matchAndRewrite(vector::BitCastOp bitcastOp, OpAdaptor adaptor,
 
   87                  ConversionPatternRewriter &rewriter)
 const override {
 
   88    Type dstType = getTypeConverter()->convertType(bitcastOp.getType());
 
   92    if (dstType == adaptor.getSource().getType()) {
 
   93      rewriter.replaceOp(bitcastOp, adaptor.getSource());
 
  100    Type srcType = adaptor.getSource().getType();
 
  102      return rewriter.notifyMatchFailure(
 
  104          llvm::formatv(
"different source ({0}) and target ({1}) bitwidth",
 
  108    rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, dstType,
 
  109                                                  adaptor.getSource());
 
  114struct VectorBroadcastConvert final
 
  115    : 
public OpConversionPattern<vector::BroadcastOp> {
 
  119  matchAndRewrite(vector::BroadcastOp castOp, OpAdaptor adaptor,
 
  120                  ConversionPatternRewriter &rewriter)
 const override {
 
  122        getTypeConverter()->convertType(castOp.getResultVectorType());
 
  126    if (isa<spirv::ScalarType>(resultType)) {
 
  127      rewriter.replaceOp(castOp, adaptor.getSource());
 
  131    SmallVector<Value, 4> source(castOp.getResultVectorType().getNumElements(),
 
  132                                 adaptor.getSource());
 
  133    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(castOp, resultType,
 
  145static Value sanitizeDynamicIndex(ConversionPatternRewriter &rewriter,
 
  147                                  int64_t kPoisonIndex, 
unsigned vectorSize) {
 
  148  if (llvm::isPowerOf2_32(vectorSize)) {
 
  149    Value inBoundsMask = spirv::ConstantOp::create(
 
  150        rewriter, loc, dynamicIndex.
getType(),
 
  151        rewriter.getIntegerAttr(dynamicIndex.
getType(), vectorSize - 1));
 
  152    return spirv::BitwiseAndOp::create(rewriter, loc, dynamicIndex,
 
  155  Value poisonIndex = spirv::ConstantOp::create(
 
  156      rewriter, loc, dynamicIndex.
getType(),
 
  157      rewriter.getIntegerAttr(dynamicIndex.
getType(), kPoisonIndex));
 
  159      spirv::IEqualOp::create(rewriter, loc, dynamicIndex, poisonIndex);
 
  160  return spirv::SelectOp::create(
 
  161      rewriter, loc, cmpResult,
 
  162      spirv::ConstantOp::getZero(dynamicIndex.
getType(), loc, rewriter),
 
  166struct VectorExtractOpConvert final
 
  167    : 
public OpConversionPattern<vector::ExtractOp> {
 
  171  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
 
  172                  ConversionPatternRewriter &rewriter)
 const override {
 
  173    Type dstType = getTypeConverter()->convertType(extractOp.getType());
 
  177    if (isa<spirv::ScalarType>(adaptor.getSource().getType())) {
 
  178      rewriter.replaceOp(extractOp, adaptor.getSource());
 
  182    if (std::optional<int64_t> 
id =
 
  184      if (
id == vector::ExtractOp::kPoisonIndex)
 
  185        return rewriter.notifyMatchFailure(
 
  187            "Static use of poison index handled elsewhere (folded to poison)");
 
  188      rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
 
  189          extractOp, dstType, adaptor.getSource(),
 
  190          rewriter.getI32ArrayAttr(
id.value()));
 
  192      Value sanitizedIndex = sanitizeDynamicIndex(
 
  193          rewriter, extractOp.getLoc(), adaptor.getDynamicPosition()[0],
 
  194          vector::ExtractOp::kPoisonIndex,
 
  195          extractOp.getSourceVectorType().getNumElements());
 
  196      rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
 
  197          extractOp, dstType, adaptor.getSource(), sanitizedIndex);
 
  203struct VectorExtractStridedSliceOpConvert final
 
  204    : 
public OpConversionPattern<vector::ExtractStridedSliceOp> {
 
  208  matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
 
  209                  ConversionPatternRewriter &rewriter)
 const override {
 
  210    Type dstType = getTypeConverter()->convertType(extractOp.getType());
 
  220    Value srcVector = adaptor.getOperands().front();
 
  223    if (isa<spirv::ScalarType>(dstType)) {
 
  224      rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(extractOp,
 
  229    SmallVector<int32_t, 2> 
indices(size);
 
  232    rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
 
  233        extractOp, dstType, srcVector, srcVector,
 
  234        rewriter.getI32ArrayAttr(
indices));
 
  240template <
class SPIRVFMAOp>
 
  241struct VectorFmaOpConvert final : 
public OpConversionPattern<vector::FMAOp> {
 
  245  matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
 
  246                  ConversionPatternRewriter &rewriter)
 const override {
 
  247    Type dstType = getTypeConverter()->convertType(fmaOp.getType());
 
  250    rewriter.replaceOpWithNewOp<SPIRVFMAOp>(fmaOp, dstType, adaptor.getLhs(),
 
  251                                            adaptor.getRhs(), adaptor.getAcc());
 
  256struct VectorFromElementsOpConvert final
 
  257    : 
public OpConversionPattern<vector::FromElementsOp> {
 
  261  matchAndRewrite(vector::FromElementsOp op, OpAdaptor adaptor,
 
  262                  ConversionPatternRewriter &rewriter)
 const override {
 
  263    Type resultType = getTypeConverter()->convertType(op.getType());
 
  267    if (isa<spirv::ScalarType>(resultType)) {
 
  270      rewriter.replaceOp(op, elements[0]);
 
  275    assert(cast<VectorType>(resultType).getRank() == 1);
 
  276    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, resultType,
 
  282struct VectorInsertOpConvert final
 
  283    : 
public OpConversionPattern<vector::InsertOp> {
 
  287  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
 
  288                  ConversionPatternRewriter &rewriter)
 const override {
 
  289    if (isa<VectorType>(insertOp.getValueToStoreType()))
 
  290      return rewriter.notifyMatchFailure(insertOp, 
"unsupported vector source");
 
  291    if (!getTypeConverter()->convertType(insertOp.getDestVectorType()))
 
  292      return rewriter.notifyMatchFailure(insertOp,
 
  293                                         "unsupported dest vector type");
 
  296    if (insertOp.getValueToStoreType().isIntOrFloat() &&
 
  297        insertOp.getDestVectorType().getNumElements() == 1) {
 
  298      rewriter.replaceOp(insertOp, adaptor.getValueToStore());
 
  302    if (std::optional<int64_t> 
id =
 
  304      if (
id == vector::InsertOp::kPoisonIndex)
 
  305        return rewriter.notifyMatchFailure(
 
  307            "Static use of poison index handled elsewhere (folded to poison)");
 
  308      rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
 
  309          insertOp, adaptor.getValueToStore(), adaptor.getDest(), 
id.value());
 
  311      Value sanitizedIndex = sanitizeDynamicIndex(
 
  312          rewriter, insertOp.getLoc(), adaptor.getDynamicPosition()[0],
 
  313          vector::InsertOp::kPoisonIndex,
 
  314          insertOp.getDestVectorType().getNumElements());
 
  315      rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
 
  316          insertOp, insertOp.getDest(), adaptor.getValueToStore(),
 
  323struct VectorInsertStridedSliceOpConvert final
 
  324    : 
public OpConversionPattern<vector::InsertStridedSliceOp> {
 
  328  matchAndRewrite(vector::InsertStridedSliceOp insertOp, OpAdaptor adaptor,
 
  329                  ConversionPatternRewriter &rewriter)
 const override {
 
  330    Value srcVector = adaptor.getOperands().front();
 
  331    Value dstVector = adaptor.getOperands().back();
 
  338    if (isa<spirv::ScalarType>(srcVector.
getType())) {
 
  339      assert(!isa<spirv::ScalarType>(dstVector.
getType()));
 
  340      rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
 
  341          insertOp, dstVector.
getType(), srcVector, dstVector,
 
  342          rewriter.getI32ArrayAttr(offset));
 
  346    uint64_t totalSize = cast<VectorType>(dstVector.
getType()).getNumElements();
 
  347    uint64_t insertSize =
 
  348        cast<VectorType>(srcVector.
getType()).getNumElements();
 
  350    SmallVector<int32_t, 2> 
indices(totalSize);
 
  352    std::iota(
indices.begin() + offset, 
indices.begin() + offset + insertSize,
 
  355    rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
 
  356        insertOp, dstVector.
getType(), dstVector, srcVector,
 
  357        rewriter.getI32ArrayAttr(
indices));
 
  364    vector::ReductionOp reduceOp, vector::ReductionOp::Adaptor adaptor,
 
  365    VectorType srcVectorType, ConversionPatternRewriter &rewriter) {
 
  366  int numElements = 
static_cast<int>(srcVectorType.getDimSize(0));
 
  368  values.reserve(numElements + (adaptor.getAcc() ? 1 : 0));
 
  371  for (
int i = 0; i < numElements; ++i) {
 
  372    values.push_back(spirv::CompositeExtractOp::create(
 
  373        rewriter, loc, srcVectorType.getElementType(), adaptor.getVector(),
 
  374        rewriter.getI32ArrayAttr({i})));
 
  377    values.push_back(
acc);
 
  382struct ReductionRewriteInfo {
 
  384  SmallVector<Value> extractedElements;
 
  387FailureOr<ReductionRewriteInfo> 
static getReductionInfo(
 
  388    vector::ReductionOp op, vector::ReductionOp::Adaptor adaptor,
 
  389    ConversionPatternRewriter &rewriter, 
const TypeConverter &typeConverter) {
 
  390  Type resultType = typeConverter.convertType(op.getType());
 
  394  auto srcVectorType = dyn_cast<VectorType>(adaptor.getVector().getType());
 
  395  if (!srcVectorType || srcVectorType.getRank() != 1)
 
  396    return rewriter.notifyMatchFailure(op, 
"not a 1-D vector source");
 
  399      extractAllElements(op, adaptor, srcVectorType, rewriter);
 
  401  return ReductionRewriteInfo{resultType, std::move(extractedElements)};
 
  404template <
typename SPIRVUMaxOp, 
typename SPIRVUMinOp, 
typename SPIRVSMaxOp,
 
  405          typename SPIRVSMinOp>
 
  406struct VectorReductionPattern final : OpConversionPattern<vector::ReductionOp> {
 
  410  matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
 
  411                  ConversionPatternRewriter &rewriter)
 const override {
 
  413        getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
 
  414    if (
failed(reductionInfo))
 
  417    auto [resultType, extractedElements] = *reductionInfo;
 
  418    Location loc = reduceOp->getLoc();
 
  419    Value 
result = extractedElements.front();
 
  420    for (Value next : llvm::drop_begin(extractedElements)) {
 
  421      switch (reduceOp.getKind()) {
 
  423#define INT_AND_FLOAT_CASE(kind, iop, fop)                                     \ 
  424  case vector::CombiningKind::kind:                                            \ 
  425    if (llvm::isa<IntegerType>(resultType)) {                                  \ 
  426      result = spirv::iop::create(rewriter, loc, resultType, result, next);    \ 
  428      assert(llvm::isa<FloatType>(resultType));                                \ 
  429      result = spirv::fop::create(rewriter, loc, resultType, result, next);    \ 
  433#define INT_OR_FLOAT_CASE(kind, fop)                                           \ 
  434  case vector::CombiningKind::kind:                                            \ 
  435    result = fop::create(rewriter, loc, resultType, result, next);             \ 
  445      case vector::CombiningKind::AND:
 
  446      case vector::CombiningKind::OR:
 
  447      case vector::CombiningKind::XOR:
 
  448        return rewriter.notifyMatchFailure(reduceOp, 
"unimplemented");
 
  450        return rewriter.notifyMatchFailure(reduceOp, 
"not handled here");
 
  452#undef INT_AND_FLOAT_CASE 
  453#undef INT_OR_FLOAT_CASE 
  456    rewriter.replaceOp(reduceOp, 
result);
 
  461template <
typename SPIRVFMaxOp, 
typename SPIRVFMinOp>
 
  462struct VectorReductionFloatMinMax final
 
  463    : OpConversionPattern<vector::ReductionOp> {
 
  467  matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
 
  468                  ConversionPatternRewriter &rewriter)
 const override {
 
  470        getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
 
  471    if (
failed(reductionInfo))
 
  474    auto [resultType, extractedElements] = *reductionInfo;
 
  475    Location loc = reduceOp->getLoc();
 
  476    Value 
result = extractedElements.front();
 
  477    for (Value next : llvm::drop_begin(extractedElements)) {
 
  478      switch (reduceOp.getKind()) {
 
  480#define INT_OR_FLOAT_CASE(kind, fop)                                           \ 
  481  case vector::CombiningKind::kind:                                            \ 
  482    result = fop::create(rewriter, loc, resultType, result, next);             \ 
  491        return rewriter.notifyMatchFailure(reduceOp, 
"not handled here");
 
  493#undef INT_OR_FLOAT_CASE 
  496    rewriter.replaceOp(reduceOp, 
result);
 
  501class VectorScalarBroadcastPattern final
 
  502    : 
public OpConversionPattern<vector::BroadcastOp> {
 
  507  matchAndRewrite(vector::BroadcastOp op, OpAdaptor adaptor,
 
  508                  ConversionPatternRewriter &rewriter)
 const override {
 
  509    if (isa<VectorType>(op.getSourceType())) {
 
  510      return rewriter.notifyMatchFailure(
 
  511          op, 
"only conversion of 'broadcast from scalar' is supported");
 
  513    Type dstType = getTypeConverter()->convertType(op.getType());
 
  516    if (isa<spirv::ScalarType>(dstType)) {
 
  517      rewriter.replaceOp(op, adaptor.getSource());
 
  519      auto dstVecType = cast<VectorType>(dstType);
 
  520      SmallVector<Value, 4> source(dstVecType.getNumElements(),
 
  521                                   adaptor.getSource());
 
  522      rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstType,
 
  529struct VectorShuffleOpConvert final
 
  530    : 
public OpConversionPattern<vector::ShuffleOp> {
 
  534  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
 
  535                  ConversionPatternRewriter &rewriter)
 const override {
 
  536    VectorType oldResultType = shuffleOp.getResultVectorType();
 
  537    Type newResultType = getTypeConverter()->convertType(oldResultType);
 
  539      return rewriter.notifyMatchFailure(shuffleOp,
 
  540                                         "unsupported result vector type");
 
  542    auto mask = llvm::to_vector_of<int32_t>(shuffleOp.getMask());
 
  544    VectorType oldV1Type = shuffleOp.getV1VectorType();
 
  545    VectorType oldV2Type = shuffleOp.getV2VectorType();
 
  549    if (oldV1Type.getNumElements() > 1 && oldV2Type.getNumElements() > 1 &&
 
  550        oldResultType.getNumElements() > 1) {
 
  551      rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
 
  552          shuffleOp, newResultType, adaptor.getV1(), adaptor.getV2(),
 
  553          rewriter.getI32ArrayAttr(mask));
 
  560    auto getElementAtIdx = [&rewriter, loc = shuffleOp.getLoc()](
 
  561                               Value scalarOrVec, int32_t idx) -> Value {
 
  562      if (
auto vecTy = dyn_cast<VectorType>(scalarOrVec.getType()))
 
  563        return spirv::CompositeExtractOp::create(rewriter, loc, scalarOrVec,
 
  566      assert(idx == 0 && 
"Invalid scalar element index");
 
  570    int32_t numV1Elems = oldV1Type.getNumElements();
 
  571    SmallVector<Value> newOperands(mask.size());
 
  572    for (
auto [shuffleIdx, newOperand] : llvm::zip_equal(mask, newOperands)) {
 
  573      Value vec = adaptor.getV1();
 
  574      int32_t elementIdx = shuffleIdx;
 
  575      if (elementIdx >= numV1Elems) {
 
  576        vec = adaptor.getV2();
 
  577        elementIdx -= numV1Elems;
 
  580      newOperand = getElementAtIdx(vec, elementIdx);
 
  584    if (newOperands.size() == 1) {
 
  585      rewriter.replaceOp(shuffleOp, newOperands.front());
 
  589    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
 
  590        shuffleOp, newResultType, newOperands);
 
  595struct VectorInterleaveOpConvert final
 
  596    : 
public OpConversionPattern<vector::InterleaveOp> {
 
  600  matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
 
  601                  ConversionPatternRewriter &rewriter)
 const override {
 
  603    VectorType oldResultType = interleaveOp.getResultVectorType();
 
  604    Type newResultType = getTypeConverter()->convertType(oldResultType);
 
  606      return rewriter.notifyMatchFailure(interleaveOp,
 
  607                                         "unsupported result vector type");
 
  610    VectorType sourceType = interleaveOp.getSourceVectorType();
 
  611    int n = sourceType.getNumElements();
 
  617      Value newOperands[] = {adaptor.getLhs(), adaptor.getRhs()};
 
  618      rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
 
  619          interleaveOp, newResultType, newOperands);
 
  623    auto seq = llvm::seq<int64_t>(2 * n);
 
  624    auto indices = llvm::map_to_vector(
 
  625        seq, [n](
int i) { 
return (i % 2 ? n : 0) + i / 2; });
 
  628    rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
 
  629        interleaveOp, newResultType, adaptor.getLhs(), adaptor.getRhs(),
 
  630        rewriter.getI32ArrayAttr(
indices));
 
  636struct VectorDeinterleaveOpConvert final
 
  637    : 
public OpConversionPattern<vector::DeinterleaveOp> {
 
  641  matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
 
  642                  ConversionPatternRewriter &rewriter)
 const override {
 
  645    VectorType oldResultType = deinterleaveOp.getResultVectorType();
 
  646    Type newResultType = getTypeConverter()->convertType(oldResultType);
 
  648      return rewriter.notifyMatchFailure(deinterleaveOp,
 
  649                                         "unsupported result vector type");
 
  651    Location loc = deinterleaveOp->getLoc();
 
  654    Value sourceVector = adaptor.getSource();
 
  655    VectorType sourceType = deinterleaveOp.getSourceVectorType();
 
  656    int n = sourceType.getNumElements();
 
  662      auto elem0 = spirv::CompositeExtractOp::create(
 
  663          rewriter, loc, newResultType, sourceVector,
 
  664          rewriter.getI32ArrayAttr({0}));
 
  666      auto elem1 = spirv::CompositeExtractOp::create(
 
  667          rewriter, loc, newResultType, sourceVector,
 
  668          rewriter.getI32ArrayAttr({1}));
 
  670      rewriter.replaceOp(deinterleaveOp, {elem0, elem1});
 
  675    auto seqEven = llvm::seq<int64_t>(n / 2);
 
  677        llvm::map_to_vector(seqEven, [](
int i) { 
return i * 2; });
 
  680    auto seqOdd = llvm::seq<int64_t>(n / 2);
 
  682        llvm::map_to_vector(seqOdd, [](
int i) { 
return i * 2 + 1; });
 
  685    auto shuffleEven = spirv::VectorShuffleOp::create(
 
  686        rewriter, loc, newResultType, sourceVector, sourceVector,
 
  687        rewriter.getI32ArrayAttr(indicesEven));
 
  689    auto shuffleOdd = spirv::VectorShuffleOp::create(
 
  690        rewriter, loc, newResultType, sourceVector, sourceVector,
 
  691        rewriter.getI32ArrayAttr(indicesOdd));
 
  693    rewriter.replaceOp(deinterleaveOp, {shuffleEven, shuffleOdd});
 
  698struct VectorLoadOpConverter final
 
  699    : 
public OpConversionPattern<vector::LoadOp> {
 
  703  matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
 
  704                  ConversionPatternRewriter &rewriter)
 const override {
 
  705    auto memrefType = loadOp.getMemRefType();
 
  707        dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
 
  709      return rewriter.notifyMatchFailure(
 
  710          loadOp, 
"expected spirv.storage_class memory space");
 
  712    const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
 
  713    auto loc = loadOp.getLoc();
 
  716                             adaptor.getIndices(), loc, rewriter);
 
  718      return rewriter.notifyMatchFailure(
 
  719          loadOp, 
"failed to get memref element pointer");
 
  721    spirv::StorageClass storageClass = attr.getValue();
 
  722    auto vectorType = loadOp.getVectorType();
 
  725    auto spirvVectorType = typeConverter.convertType(vectorType);
 
  726    if (!spirvVectorType)
 
  727      return rewriter.notifyMatchFailure(loadOp, 
"unsupported vector type");
 
  731    std::optional<uint64_t> alignment = loadOp.getAlignment();
 
  732    if (alignment > std::numeric_limits<uint32_t>::max()) {
 
  733      return rewriter.notifyMatchFailure(loadOp,
 
  734                                         "invalid alignment requirement");
 
  737    auto memoryAccess = spirv::MemoryAccess::None;
 
  738    spirv::MemoryAccessAttr memoryAccessAttr;
 
  739    IntegerAttr alignmentAttr;
 
  740    if (alignment.has_value()) {
 
  741      memoryAccess |= spirv::MemoryAccess::Aligned;
 
  743          spirv::MemoryAccessAttr::get(rewriter.getContext(), memoryAccess);
 
  744      alignmentAttr = rewriter.getI32IntegerAttr(alignment.value());
 
  750    Value castedAccessChain =
 
  751        (vectorType.getNumElements() == 1)
 
  753            : spirv::BitcastOp::create(rewriter, loc, vectorPtrType,
 
  756    rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, spirvVectorType,
 
  758                                               memoryAccessAttr, alignmentAttr);
 
  764struct VectorStoreOpConverter final
 
  765    : 
public OpConversionPattern<vector::StoreOp> {
 
  769  matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
 
  770                  ConversionPatternRewriter &rewriter)
 const override {
 
  771    auto memrefType = storeOp.getMemRefType();
 
  773        dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
 
  775      return rewriter.notifyMatchFailure(
 
  776          storeOp, 
"expected spirv.storage_class memory space");
 
  778    const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
 
  779    auto loc = storeOp.getLoc();
 
  782                             adaptor.getIndices(), loc, rewriter);
 
  784      return rewriter.notifyMatchFailure(
 
  785          storeOp, 
"failed to get memref element pointer");
 
  787    std::optional<uint64_t> alignment = storeOp.getAlignment();
 
  788    if (alignment > std::numeric_limits<uint32_t>::max()) {
 
  789      return rewriter.notifyMatchFailure(storeOp,
 
  790                                         "invalid alignment requirement");
 
  793    spirv::StorageClass storageClass = attr.getValue();
 
  794    auto vectorType = storeOp.getVectorType();
 
  800    Value castedAccessChain =
 
  801        (vectorType.getNumElements() == 1)
 
  803            : spirv::BitcastOp::create(rewriter, loc, vectorPtrType,
 
  806    auto memoryAccess = spirv::MemoryAccess::None;
 
  807    spirv::MemoryAccessAttr memoryAccessAttr;
 
  808    IntegerAttr alignmentAttr;
 
  809    if (alignment.has_value()) {
 
  810      memoryAccess |= spirv::MemoryAccess::Aligned;
 
  812          spirv::MemoryAccessAttr::get(rewriter.getContext(), memoryAccess);
 
  813      alignmentAttr = rewriter.getI32IntegerAttr(alignment.value());
 
  816    rewriter.replaceOpWithNewOp<spirv::StoreOp>(
 
  817        storeOp, castedAccessChain, adaptor.getValueToStore(), memoryAccessAttr,
 
  824struct VectorReductionToIntDotProd final
 
  828  LogicalResult matchAndRewrite(vector::ReductionOp op,
 
  829                                PatternRewriter &rewriter)
 const override {
 
  830    if (op.getKind() != vector::CombiningKind::ADD)
 
  833    auto resultType = dyn_cast<IntegerType>(op.getType());
 
  838    if (!llvm::is_contained({32, 64}, resultBitwidth))
 
  841    VectorType inVecTy = op.getSourceVectorType();
 
  842    if (!llvm::is_contained({4, 3}, inVecTy.getNumElements()) ||
 
  843        inVecTy.getShape().size() != 1 || inVecTy.isScalable())
 
  846    auto mul = op.getVector().getDefiningOp<arith::MulIOp>();
 
  849          op, 
"reduction operand is not 'arith.muli'");
 
  851    if (succeeded(handleCase<arith::ExtSIOp, arith::ExtSIOp, spirv::SDotOp,
 
  852                             spirv::SDotAccSatOp, 
false>(op, 
mul, rewriter)))
 
  855    if (succeeded(handleCase<arith::ExtUIOp, arith::ExtUIOp, spirv::UDotOp,
 
  856                             spirv::UDotAccSatOp, 
false>(op, 
mul, rewriter)))
 
  859    if (succeeded(handleCase<arith::ExtSIOp, arith::ExtUIOp, spirv::SUDotOp,
 
  860                             spirv::SUDotAccSatOp, 
false>(op, 
mul, rewriter)))
 
  863    if (succeeded(handleCase<arith::ExtUIOp, arith::ExtSIOp, spirv::SUDotOp,
 
  864                             spirv::SUDotAccSatOp, 
true>(op, 
mul, rewriter)))
 
  871  template <
typename LhsExtensionOp, 
typename RhsExtensionOp, 
typename DotOp,
 
  872            typename DotAccOp, 
bool SwapOperands>
 
  873  static LogicalResult handleCase(vector::ReductionOp op, arith::MulIOp 
mul,
 
  874                                  PatternRewriter &rewriter) {
 
  875    auto lhs = 
mul.getLhs().getDefiningOp<LhsExtensionOp>();
 
  878    Value lhsIn = 
lhs.getIn();
 
  879    auto lhsInType = cast<VectorType>(lhsIn.
getType());
 
  880    if (!lhsInType.getElementType().isInteger(8))
 
  883    auto rhs = 
mul.getRhs().getDefiningOp<RhsExtensionOp>();
 
  886    Value rhsIn = 
rhs.getIn();
 
  887    auto rhsInType = cast<VectorType>(rhsIn.
getType());
 
  888    if (!rhsInType.getElementType().isInteger(8))
 
  891    if (op.getSourceVectorType().getNumElements() == 3) {
 
  892      IntegerType i8Type = rewriter.
getI8Type();
 
  893      auto v4i8Type = VectorType::get({4}, i8Type);
 
  894      Location loc = op.getLoc();
 
  895      Value zero = spirv::ConstantOp::getZero(i8Type, loc, rewriter);
 
  896      lhsIn = spirv::CompositeConstructOp::create(rewriter, loc, v4i8Type,
 
  898      rhsIn = spirv::CompositeConstructOp::create(rewriter, loc, v4i8Type,
 
  905      std::swap(lhsIn, rhsIn);
 
  907    if (Value acc = op.getAcc()) {
 
  919struct VectorReductionToFPDotProd final
 
  920    : OpConversionPattern<vector::ReductionOp> {
 
  924  matchAndRewrite(vector::ReductionOp op, OpAdaptor adaptor,
 
  925                  ConversionPatternRewriter &rewriter)
 const override {
 
  926    if (op.getKind() != vector::CombiningKind::ADD)
 
  927      return rewriter.notifyMatchFailure(op, 
"combining kind is not 'add'");
 
  929    auto resultType = getTypeConverter()->convertType<FloatType>(op.getType());
 
  931      return rewriter.notifyMatchFailure(op, 
"result is not a float");
 
  933    Value vec = adaptor.getVector();
 
  934    Value acc = adaptor.getAcc();
 
  936    auto vectorType = dyn_cast<VectorType>(vec.
getType());
 
  938      assert(isa<FloatType>(vec.
getType()) &&
 
  939             "Expected the vector to be scalarized");
 
  941        rewriter.replaceOpWithNewOp<spirv::FAddOp>(op, acc, vec);
 
  945      rewriter.replaceOp(op, vec);
 
  949    Location loc = op.getLoc();
 
  960          rewriter.getFloatAttr(vectorType.getElementType(), 1.0);
 
  961      oneAttr = SplatElementsAttr::get(vectorType, oneAttr);
 
  962      rhs = spirv::ConstantOp::create(rewriter, loc, vectorType, oneAttr);
 
  967    Value res = spirv::DotOp::create(rewriter, loc, resultType, 
lhs, 
rhs);
 
  969      res = spirv::FAddOp::create(rewriter, loc, acc, res);
 
  971    rewriter.replaceOp(op, res);
 
  976struct VectorStepOpConvert final : OpConversionPattern<vector::StepOp> {
 
  980  matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
 
  981                  ConversionPatternRewriter &rewriter)
 const override {
 
  982    const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
 
  983    Type dstType = typeConverter.convertType(stepOp.getType());
 
  987    Location loc = stepOp.getLoc();
 
  988    int64_t numElements = stepOp.getType().getNumElements();
 
  990        rewriter.getIntegerType(typeConverter.getIndexTypeBitwidth());
 
  994    if (numElements == 1) {
 
  995      Value zero = spirv::ConstantOp::getZero(intType, loc, rewriter);
 
  996      rewriter.replaceOp(stepOp, zero);
 
 1000    SmallVector<Value> source;
 
 1001    source.reserve(numElements);
 
 1002    for (int64_t i = 0; i < numElements; ++i) {
 
 1003      Attribute intAttr = rewriter.getIntegerAttr(intType, i);
 
 1005          spirv::ConstantOp::create(rewriter, loc, intType, intAttr);
 
 1006      source.push_back(constOp);
 
 1008    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(stepOp, dstType,
 
 1014struct VectorToElementOpConvert final
 
 1015    : OpConversionPattern<vector::ToElementsOp> {
 
 1019  matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
 
 1020                  ConversionPatternRewriter &rewriter)
 const override {
 
 1022    SmallVector<Value> results(toElementsOp->getNumResults());
 
 1023    Location loc = toElementsOp.getLoc();
 
 1028    if (isa<spirv::ScalarType>(adaptor.getSource().getType())) {
 
 1029      results[0] = adaptor.getSource();
 
 1030      rewriter.replaceOp(toElementsOp, results);
 
 1034    Type srcElementType = toElementsOp.getElements().getType().front();
 
 1035    Type elementType = getTypeConverter()->convertType(srcElementType);
 
 1037      return rewriter.notifyMatchFailure(
 
 1039          llvm::formatv(
"failed to convert element type '{0}' to SPIR-V",
 
 1042    for (
auto [idx, element] : llvm::enumerate(toElementsOp.getElements())) {
 
 1045      if (element.use_empty())
 
 1048      Value 
result = spirv::CompositeExtractOp::create(
 
 1049          rewriter, loc, elementType, adaptor.getSource(),
 
 1050          rewriter.getI32ArrayAttr({static_cast<int32_t>(idx)}));
 
 1054    rewriter.replaceOp(toElementsOp, results);
 
 1060#define CL_INT_MAX_MIN_OPS                                                     \ 
 1061  spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp 
 
 1063#define GL_INT_MAX_MIN_OPS                                                     \ 
 1064  spirv::GLUMaxOp, spirv::GLUMinOp, spirv::GLSMaxOp, spirv::GLSMinOp 
 
 1066#define CL_FLOAT_MAX_MIN_OPS spirv::CLFMaxOp, spirv::CLFMinOp 
 1067#define GL_FLOAT_MAX_MIN_OPS spirv::GLFMaxOp, spirv::GLFMinOp 
 1072      VectorBitcastConvert, VectorBroadcastConvert, VectorExtractOpConvert,
 
 1073      VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
 
 1074      VectorFmaOpConvert<spirv::CLFmaOp>, VectorFromElementsOpConvert,
 
 1075      VectorToElementOpConvert, VectorInsertOpConvert,
 
 1076      VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
 
 1077      VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
 
 1078      VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
 
 1079      VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
 
 1080      VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
 
 1081      VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
 
 1082      VectorScalarBroadcastPattern, VectorLoadOpConverter,
 
 1083      VectorStoreOpConverter, VectorStepOpConvert>(
 
 1088  patterns.add<VectorReductionToFPDotProd>(typeConverter, 
patterns.getContext(),
 
 
static constexpr unsigned getNumBits()
 
static uint64_t getFirstIntValue(ArrayAttr attr)
Returns the integer value from the first valid input element, assuming Value inputs are defined by constant index ops and Attribute inputs are integer attributes.
 
#define INT_AND_FLOAT_CASE(kind, iop, fop)
 
#define INT_OR_FLOAT_CASE(kind, fop)
 
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
 
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65K.
 
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
 
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (replacing `op`).
 
Type conversion from builtin types to SPIR-V types for shader interface.
 
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
 
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
 
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
 
Type getType() const
Return the type of this value.
 
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
 
static PointerType get(Type pointeeType, StorageClass storageClass)
 
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at `indices` of the memory pointed to by `basePtr`, using the layout map of `baseType`.
 
Include the generated interface declarations.
 
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
 
void populateVectorReductionToSPIRVDotProductPatterns(RewritePatternSet &patterns)
Appends patterns to convert vector reduction of the form:
 
const FrozenRewritePatternSet & patterns
 
void populateVectorToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating Vector Ops to SPIR-V ops.
 
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.