28#include "llvm/ADT/ArrayRef.h"
29#include "llvm/ADT/STLExtras.h"
30#include "llvm/ADT/SmallVector.h"
31#include "llvm/ADT/SmallVectorExtras.h"
32#include "llvm/Support/FormatVariadic.h"
43 return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
51 if (
auto vectorType = dyn_cast<VectorType>(type))
52 return vectorType.getNumElements() * vectorType.getElementTypeBitWidth();
58struct VectorShapeCast final :
public OpConversionPattern<vector::ShapeCastOp> {
62 matchAndRewrite(vector::ShapeCastOp shapeCastOp, OpAdaptor adaptor,
63 ConversionPatternRewriter &rewriter)
const override {
64 Type dstType = getTypeConverter()->convertType(shapeCastOp.getType());
70 if (dstType == adaptor.getSource().getType() ||
71 shapeCastOp.getResultVectorType().getNumElements() == 1) {
72 rewriter.replaceOp(shapeCastOp, adaptor.getSource());
81struct VectorBitcastConvert final
82 :
public OpConversionPattern<vector::BitCastOp> {
86 matchAndRewrite(vector::BitCastOp bitcastOp, OpAdaptor adaptor,
87 ConversionPatternRewriter &rewriter)
const override {
88 Type dstType = getTypeConverter()->convertType(bitcastOp.getType());
92 if (dstType == adaptor.getSource().getType()) {
93 rewriter.replaceOp(bitcastOp, adaptor.getSource());
100 Type srcType = adaptor.getSource().getType();
102 return rewriter.notifyMatchFailure(
104 llvm::formatv(
"different source ({0}) and target ({1}) bitwidth",
108 rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, dstType,
109 adaptor.getSource());
114struct VectorBroadcastConvert final
115 :
public OpConversionPattern<vector::BroadcastOp> {
119 matchAndRewrite(vector::BroadcastOp castOp, OpAdaptor adaptor,
120 ConversionPatternRewriter &rewriter)
const override {
122 getTypeConverter()->convertType(castOp.getResultVectorType());
126 if (isa<spirv::ScalarType>(resultType)) {
127 rewriter.replaceOp(castOp, adaptor.getSource());
131 SmallVector<Value, 4> source(castOp.getResultVectorType().getNumElements(),
132 adaptor.getSource());
133 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(castOp, resultType,
145static Value sanitizeDynamicIndex(ConversionPatternRewriter &rewriter,
147 int64_t kPoisonIndex,
unsigned vectorSize) {
148 if (llvm::isPowerOf2_32(vectorSize)) {
149 Value inBoundsMask = spirv::ConstantOp::create(
150 rewriter, loc, dynamicIndex.
getType(),
151 rewriter.getIntegerAttr(dynamicIndex.
getType(), vectorSize - 1));
152 return spirv::BitwiseAndOp::create(rewriter, loc, dynamicIndex,
155 Value poisonIndex = spirv::ConstantOp::create(
156 rewriter, loc, dynamicIndex.
getType(),
157 rewriter.getIntegerAttr(dynamicIndex.
getType(), kPoisonIndex));
159 spirv::IEqualOp::create(rewriter, loc, dynamicIndex, poisonIndex);
160 return spirv::SelectOp::create(
161 rewriter, loc, cmpResult,
162 spirv::ConstantOp::getZero(dynamicIndex.
getType(), loc, rewriter),
166struct VectorExtractOpConvert final
167 :
public OpConversionPattern<vector::ExtractOp> {
171 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
172 ConversionPatternRewriter &rewriter)
const override {
173 Type dstType = getTypeConverter()->convertType(extractOp.getType());
177 if (isa<spirv::ScalarType>(adaptor.getSource().getType())) {
178 rewriter.replaceOp(extractOp, adaptor.getSource());
182 if (std::optional<int64_t>
id =
184 if (
id == vector::ExtractOp::kPoisonIndex)
185 return rewriter.notifyMatchFailure(
187 "Static use of poison index handled elsewhere (folded to poison)");
188 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
189 extractOp, dstType, adaptor.getSource(),
190 rewriter.getI32ArrayAttr(
id.value()));
192 Value sanitizedIndex = sanitizeDynamicIndex(
193 rewriter, extractOp.getLoc(), adaptor.getDynamicPosition()[0],
194 vector::ExtractOp::kPoisonIndex,
195 extractOp.getSourceVectorType().getNumElements());
196 rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
197 extractOp, dstType, adaptor.getSource(), sanitizedIndex);
203struct VectorExtractStridedSliceOpConvert final
204 :
public OpConversionPattern<vector::ExtractStridedSliceOp> {
208 matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
209 ConversionPatternRewriter &rewriter)
const override {
210 Type dstType = getTypeConverter()->convertType(extractOp.getType());
220 Value srcVector = adaptor.getOperands().front();
223 if (isa<spirv::ScalarType>(dstType)) {
224 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(extractOp,
229 SmallVector<int32_t, 2>
indices(size);
232 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
233 extractOp, dstType, srcVector, srcVector,
234 rewriter.getI32ArrayAttr(
indices));
240template <
class SPIRVFMAOp>
241struct VectorFmaOpConvert final :
public OpConversionPattern<vector::FMAOp> {
245 matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
246 ConversionPatternRewriter &rewriter)
const override {
247 Type dstType = getTypeConverter()->convertType(fmaOp.getType());
250 rewriter.replaceOpWithNewOp<SPIRVFMAOp>(fmaOp, dstType, adaptor.getLhs(),
251 adaptor.getRhs(), adaptor.getAcc());
256struct VectorFromElementsOpConvert final
257 :
public OpConversionPattern<vector::FromElementsOp> {
261 matchAndRewrite(vector::FromElementsOp op, OpAdaptor adaptor,
262 ConversionPatternRewriter &rewriter)
const override {
263 Type resultType = getTypeConverter()->convertType(op.getType());
267 if (isa<spirv::ScalarType>(resultType)) {
270 rewriter.replaceOp(op, elements[0]);
275 assert(cast<VectorType>(resultType).getRank() == 1);
276 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, resultType,
282struct VectorInsertOpConvert final
283 :
public OpConversionPattern<vector::InsertOp> {
287 matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
288 ConversionPatternRewriter &rewriter)
const override {
289 if (isa<VectorType>(insertOp.getValueToStoreType()))
290 return rewriter.notifyMatchFailure(insertOp,
"unsupported vector source");
291 if (!getTypeConverter()->convertType(insertOp.getDestVectorType()))
292 return rewriter.notifyMatchFailure(insertOp,
293 "unsupported dest vector type");
296 if (insertOp.getValueToStoreType().isIntOrFloat() &&
297 insertOp.getDestVectorType().getNumElements() == 1) {
298 rewriter.replaceOp(insertOp, adaptor.getValueToStore());
302 if (std::optional<int64_t>
id =
304 if (
id == vector::InsertOp::kPoisonIndex)
305 return rewriter.notifyMatchFailure(
307 "Static use of poison index handled elsewhere (folded to poison)");
308 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
309 insertOp, adaptor.getValueToStore(), adaptor.getDest(),
id.value());
311 Value sanitizedIndex = sanitizeDynamicIndex(
312 rewriter, insertOp.getLoc(), adaptor.getDynamicPosition()[0],
313 vector::InsertOp::kPoisonIndex,
314 insertOp.getDestVectorType().getNumElements());
315 rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
316 insertOp, insertOp.getDest(), adaptor.getValueToStore(),
323struct VectorInsertStridedSliceOpConvert final
324 :
public OpConversionPattern<vector::InsertStridedSliceOp> {
328 matchAndRewrite(vector::InsertStridedSliceOp insertOp, OpAdaptor adaptor,
329 ConversionPatternRewriter &rewriter)
const override {
330 Value srcVector = adaptor.getOperands().front();
331 Value dstVector = adaptor.getOperands().back();
338 if (isa<spirv::ScalarType>(srcVector.
getType())) {
339 assert(!isa<spirv::ScalarType>(dstVector.
getType()));
340 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
341 insertOp, dstVector.
getType(), srcVector, dstVector,
342 rewriter.getI32ArrayAttr(offset));
346 uint64_t totalSize = cast<VectorType>(dstVector.
getType()).getNumElements();
347 uint64_t insertSize =
348 cast<VectorType>(srcVector.
getType()).getNumElements();
350 SmallVector<int32_t, 2>
indices(totalSize);
352 std::iota(
indices.begin() + offset,
indices.begin() + offset + insertSize,
355 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
356 insertOp, dstVector.
getType(), dstVector, srcVector,
357 rewriter.getI32ArrayAttr(
indices));
364 vector::ReductionOp reduceOp, vector::ReductionOp::Adaptor adaptor,
365 VectorType srcVectorType, ConversionPatternRewriter &rewriter) {
366 int numElements =
static_cast<int>(srcVectorType.getDimSize(0));
368 values.reserve(numElements + (adaptor.getAcc() ? 1 : 0));
371 for (
int i = 0; i < numElements; ++i) {
372 values.push_back(spirv::CompositeExtractOp::create(
373 rewriter, loc, srcVectorType.getElementType(), adaptor.getVector(),
374 rewriter.getI32ArrayAttr({i})));
377 values.push_back(
acc);
382struct ReductionRewriteInfo {
384 SmallVector<Value> extractedElements;
387FailureOr<ReductionRewriteInfo>
static getReductionInfo(
388 vector::ReductionOp op, vector::ReductionOp::Adaptor adaptor,
389 ConversionPatternRewriter &rewriter,
const TypeConverter &typeConverter) {
390 Type resultType = typeConverter.convertType(op.getType());
394 auto srcVectorType = dyn_cast<VectorType>(adaptor.getVector().getType());
395 if (!srcVectorType || srcVectorType.getRank() != 1)
396 return rewriter.notifyMatchFailure(op,
"not a 1-D vector source");
399 extractAllElements(op, adaptor, srcVectorType, rewriter);
401 return ReductionRewriteInfo{resultType, std::move(extractedElements)};
404template <
typename SPIRVUMaxOp,
typename SPIRVUMinOp,
typename SPIRVSMaxOp,
405 typename SPIRVSMinOp>
406struct VectorReductionPattern final : OpConversionPattern<vector::ReductionOp> {
410 matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
411 ConversionPatternRewriter &rewriter)
const override {
413 getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
414 if (
failed(reductionInfo))
417 auto [resultType, extractedElements] = *reductionInfo;
418 Location loc = reduceOp->getLoc();
419 Value
result = extractedElements.front();
420 for (Value next : llvm::drop_begin(extractedElements)) {
421 switch (reduceOp.getKind()) {
423#define INT_AND_FLOAT_CASE(kind, iop, fop) \
424 case vector::CombiningKind::kind: \
425 if (llvm::isa<IntegerType>(resultType)) { \
426 result = spirv::iop::create(rewriter, loc, resultType, result, next); \
428 assert(llvm::isa<FloatType>(resultType)); \
429 result = spirv::fop::create(rewriter, loc, resultType, result, next); \
433#define INT_OR_FLOAT_CASE(kind, fop) \
434 case vector::CombiningKind::kind: \
435 result = fop::create(rewriter, loc, resultType, result, next); \
445 case vector::CombiningKind::AND:
446 case vector::CombiningKind::OR:
447 case vector::CombiningKind::XOR:
448 return rewriter.notifyMatchFailure(reduceOp,
"unimplemented");
450 return rewriter.notifyMatchFailure(reduceOp,
"not handled here");
452#undef INT_AND_FLOAT_CASE
453#undef INT_OR_FLOAT_CASE
456 rewriter.replaceOp(reduceOp,
result);
461template <
typename SPIRVFMaxOp,
typename SPIRVFMinOp>
462struct VectorReductionFloatMinMax final
463 : OpConversionPattern<vector::ReductionOp> {
467 matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
468 ConversionPatternRewriter &rewriter)
const override {
470 getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
471 if (
failed(reductionInfo))
474 auto [resultType, extractedElements] = *reductionInfo;
475 Location loc = reduceOp->getLoc();
476 Value
result = extractedElements.front();
477 for (Value next : llvm::drop_begin(extractedElements)) {
478 switch (reduceOp.getKind()) {
480#define INT_OR_FLOAT_CASE(kind, fop) \
481 case vector::CombiningKind::kind: \
482 result = fop::create(rewriter, loc, resultType, result, next); \
491 return rewriter.notifyMatchFailure(reduceOp,
"not handled here");
493#undef INT_OR_FLOAT_CASE
496 rewriter.replaceOp(reduceOp,
result);
501class VectorScalarBroadcastPattern final
502 :
public OpConversionPattern<vector::BroadcastOp> {
507 matchAndRewrite(vector::BroadcastOp op, OpAdaptor adaptor,
508 ConversionPatternRewriter &rewriter)
const override {
509 if (isa<VectorType>(op.getSourceType())) {
510 return rewriter.notifyMatchFailure(
511 op,
"only conversion of 'broadcast from scalar' is supported");
513 Type dstType = getTypeConverter()->convertType(op.getType());
516 if (isa<spirv::ScalarType>(dstType)) {
517 rewriter.replaceOp(op, adaptor.getSource());
519 auto dstVecType = cast<VectorType>(dstType);
520 SmallVector<Value, 4> source(dstVecType.getNumElements(),
521 adaptor.getSource());
522 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstType,
529struct VectorShuffleOpConvert final
530 :
public OpConversionPattern<vector::ShuffleOp> {
534 matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
535 ConversionPatternRewriter &rewriter)
const override {
536 VectorType oldResultType = shuffleOp.getResultVectorType();
537 Type newResultType = getTypeConverter()->convertType(oldResultType);
539 return rewriter.notifyMatchFailure(shuffleOp,
540 "unsupported result vector type");
542 auto mask = llvm::to_vector_of<int32_t>(shuffleOp.getMask());
544 VectorType oldV1Type = shuffleOp.getV1VectorType();
545 VectorType oldV2Type = shuffleOp.getV2VectorType();
549 if (oldV1Type.getNumElements() > 1 && oldV2Type.getNumElements() > 1 &&
550 oldResultType.getNumElements() > 1) {
551 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
552 shuffleOp, newResultType, adaptor.getV1(), adaptor.getV2(),
553 rewriter.getI32ArrayAttr(mask));
560 auto getElementAtIdx = [&rewriter, loc = shuffleOp.getLoc()](
561 Value scalarOrVec, int32_t idx) -> Value {
562 if (
auto vecTy = dyn_cast<VectorType>(scalarOrVec.getType()))
563 return spirv::CompositeExtractOp::create(rewriter, loc, scalarOrVec,
566 assert(idx == 0 &&
"Invalid scalar element index");
570 int32_t numV1Elems = oldV1Type.getNumElements();
571 SmallVector<Value> newOperands(mask.size());
572 for (
auto [shuffleIdx, newOperand] : llvm::zip_equal(mask, newOperands)) {
573 Value vec = adaptor.getV1();
574 int32_t elementIdx = shuffleIdx;
575 if (elementIdx >= numV1Elems) {
576 vec = adaptor.getV2();
577 elementIdx -= numV1Elems;
580 newOperand = getElementAtIdx(vec, elementIdx);
584 if (newOperands.size() == 1) {
585 rewriter.replaceOp(shuffleOp, newOperands.front());
589 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
590 shuffleOp, newResultType, newOperands);
595struct VectorInterleaveOpConvert final
596 :
public OpConversionPattern<vector::InterleaveOp> {
600 matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
601 ConversionPatternRewriter &rewriter)
const override {
603 VectorType oldResultType = interleaveOp.getResultVectorType();
604 Type newResultType = getTypeConverter()->convertType(oldResultType);
606 return rewriter.notifyMatchFailure(interleaveOp,
607 "unsupported result vector type");
610 VectorType sourceType = interleaveOp.getSourceVectorType();
611 int n = sourceType.getNumElements();
617 Value newOperands[] = {adaptor.getLhs(), adaptor.getRhs()};
618 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
619 interleaveOp, newResultType, newOperands);
623 auto seq = llvm::seq<int64_t>(2 * n);
624 auto indices = llvm::map_to_vector(
625 seq, [n](
int i) {
return (i % 2 ? n : 0) + i / 2; });
628 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
629 interleaveOp, newResultType, adaptor.getLhs(), adaptor.getRhs(),
630 rewriter.getI32ArrayAttr(
indices));
636struct VectorDeinterleaveOpConvert final
637 :
public OpConversionPattern<vector::DeinterleaveOp> {
641 matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
642 ConversionPatternRewriter &rewriter)
const override {
645 VectorType oldResultType = deinterleaveOp.getResultVectorType();
646 Type newResultType = getTypeConverter()->convertType(oldResultType);
648 return rewriter.notifyMatchFailure(deinterleaveOp,
649 "unsupported result vector type");
651 Location loc = deinterleaveOp->getLoc();
654 Value sourceVector = adaptor.getSource();
655 VectorType sourceType = deinterleaveOp.getSourceVectorType();
656 int n = sourceType.getNumElements();
662 auto elem0 = spirv::CompositeExtractOp::create(
663 rewriter, loc, newResultType, sourceVector,
664 rewriter.getI32ArrayAttr({0}));
666 auto elem1 = spirv::CompositeExtractOp::create(
667 rewriter, loc, newResultType, sourceVector,
668 rewriter.getI32ArrayAttr({1}));
670 rewriter.replaceOp(deinterleaveOp, {elem0, elem1});
675 auto seqEven = llvm::seq<int64_t>(n / 2);
677 llvm::map_to_vector(seqEven, [](
int i) {
return i * 2; });
680 auto seqOdd = llvm::seq<int64_t>(n / 2);
682 llvm::map_to_vector(seqOdd, [](
int i) {
return i * 2 + 1; });
685 auto shuffleEven = spirv::VectorShuffleOp::create(
686 rewriter, loc, newResultType, sourceVector, sourceVector,
687 rewriter.getI32ArrayAttr(indicesEven));
689 auto shuffleOdd = spirv::VectorShuffleOp::create(
690 rewriter, loc, newResultType, sourceVector, sourceVector,
691 rewriter.getI32ArrayAttr(indicesOdd));
693 rewriter.replaceOp(deinterleaveOp, {shuffleEven, shuffleOdd});
698struct VectorLoadOpConverter final
699 :
public OpConversionPattern<vector::LoadOp> {
703 matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
704 ConversionPatternRewriter &rewriter)
const override {
705 auto memrefType = loadOp.getMemRefType();
707 dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
709 return rewriter.notifyMatchFailure(
710 loadOp,
"expected spirv.storage_class memory space");
712 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
713 auto loc = loadOp.getLoc();
716 adaptor.getIndices(), loc, rewriter);
718 return rewriter.notifyMatchFailure(
719 loadOp,
"failed to get memref element pointer");
721 spirv::StorageClass storageClass = attr.getValue();
722 auto vectorType = loadOp.getVectorType();
725 auto spirvVectorType = typeConverter.convertType(vectorType);
726 if (!spirvVectorType)
727 return rewriter.notifyMatchFailure(loadOp,
"unsupported vector type");
731 std::optional<uint64_t> alignment = loadOp.getAlignment();
732 if (alignment > std::numeric_limits<uint32_t>::max()) {
733 return rewriter.notifyMatchFailure(loadOp,
734 "invalid alignment requirement");
737 auto memoryAccess = spirv::MemoryAccess::None;
738 spirv::MemoryAccessAttr memoryAccessAttr;
739 IntegerAttr alignmentAttr;
740 if (alignment.has_value()) {
741 memoryAccess |= spirv::MemoryAccess::Aligned;
743 spirv::MemoryAccessAttr::get(rewriter.getContext(), memoryAccess);
744 alignmentAttr = rewriter.getI32IntegerAttr(alignment.value());
750 Value castedAccessChain =
751 (vectorType.getNumElements() == 1)
753 : spirv::BitcastOp::create(rewriter, loc, vectorPtrType,
756 rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, spirvVectorType,
758 memoryAccessAttr, alignmentAttr);
764struct VectorStoreOpConverter final
765 :
public OpConversionPattern<vector::StoreOp> {
769 matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
770 ConversionPatternRewriter &rewriter)
const override {
771 auto memrefType = storeOp.getMemRefType();
773 dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
775 return rewriter.notifyMatchFailure(
776 storeOp,
"expected spirv.storage_class memory space");
778 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
779 auto loc = storeOp.getLoc();
782 adaptor.getIndices(), loc, rewriter);
784 return rewriter.notifyMatchFailure(
785 storeOp,
"failed to get memref element pointer");
787 std::optional<uint64_t> alignment = storeOp.getAlignment();
788 if (alignment > std::numeric_limits<uint32_t>::max()) {
789 return rewriter.notifyMatchFailure(storeOp,
790 "invalid alignment requirement");
793 spirv::StorageClass storageClass = attr.getValue();
794 auto vectorType = storeOp.getVectorType();
800 Value castedAccessChain =
801 (vectorType.getNumElements() == 1)
803 : spirv::BitcastOp::create(rewriter, loc, vectorPtrType,
806 auto memoryAccess = spirv::MemoryAccess::None;
807 spirv::MemoryAccessAttr memoryAccessAttr;
808 IntegerAttr alignmentAttr;
809 if (alignment.has_value()) {
810 memoryAccess |= spirv::MemoryAccess::Aligned;
812 spirv::MemoryAccessAttr::get(rewriter.getContext(), memoryAccess);
813 alignmentAttr = rewriter.getI32IntegerAttr(alignment.value());
816 rewriter.replaceOpWithNewOp<spirv::StoreOp>(
817 storeOp, castedAccessChain, adaptor.getValueToStore(), memoryAccessAttr,
824struct VectorReductionToIntDotProd final
828 LogicalResult matchAndRewrite(vector::ReductionOp op,
829 PatternRewriter &rewriter)
const override {
830 if (op.getKind() != vector::CombiningKind::ADD)
833 auto resultType = dyn_cast<IntegerType>(op.getType());
838 if (!llvm::is_contained({32, 64}, resultBitwidth))
841 VectorType inVecTy = op.getSourceVectorType();
842 if (!llvm::is_contained({4, 3}, inVecTy.getNumElements()) ||
843 inVecTy.getShape().size() != 1 || inVecTy.isScalable())
846 auto mul = op.getVector().getDefiningOp<arith::MulIOp>();
849 op,
"reduction operand is not 'arith.muli'");
851 if (succeeded(handleCase<arith::ExtSIOp, arith::ExtSIOp, spirv::SDotOp,
852 spirv::SDotAccSatOp,
false>(op,
mul, rewriter)))
855 if (succeeded(handleCase<arith::ExtUIOp, arith::ExtUIOp, spirv::UDotOp,
856 spirv::UDotAccSatOp,
false>(op,
mul, rewriter)))
859 if (succeeded(handleCase<arith::ExtSIOp, arith::ExtUIOp, spirv::SUDotOp,
860 spirv::SUDotAccSatOp,
false>(op,
mul, rewriter)))
863 if (succeeded(handleCase<arith::ExtUIOp, arith::ExtSIOp, spirv::SUDotOp,
864 spirv::SUDotAccSatOp,
true>(op,
mul, rewriter)))
871 template <
typename LhsExtensionOp,
typename RhsExtensionOp,
typename DotOp,
872 typename DotAccOp,
bool SwapOperands>
873 static LogicalResult handleCase(vector::ReductionOp op, arith::MulIOp
mul,
874 PatternRewriter &rewriter) {
875 auto lhs =
mul.getLhs().getDefiningOp<LhsExtensionOp>();
878 Value lhsIn =
lhs.getIn();
879 auto lhsInType = cast<VectorType>(lhsIn.
getType());
880 if (!lhsInType.getElementType().isInteger(8))
883 auto rhs =
mul.getRhs().getDefiningOp<RhsExtensionOp>();
886 Value rhsIn =
rhs.getIn();
887 auto rhsInType = cast<VectorType>(rhsIn.
getType());
888 if (!rhsInType.getElementType().isInteger(8))
891 if (op.getSourceVectorType().getNumElements() == 3) {
892 IntegerType i8Type = rewriter.
getI8Type();
893 auto v4i8Type = VectorType::get({4}, i8Type);
894 Location loc = op.getLoc();
895 Value zero = spirv::ConstantOp::getZero(i8Type, loc, rewriter);
896 lhsIn = spirv::CompositeConstructOp::create(rewriter, loc, v4i8Type,
898 rhsIn = spirv::CompositeConstructOp::create(rewriter, loc, v4i8Type,
905 std::swap(lhsIn, rhsIn);
907 if (Value acc = op.getAcc()) {
919struct VectorReductionToFPDotProd final
920 : OpConversionPattern<vector::ReductionOp> {
924 matchAndRewrite(vector::ReductionOp op, OpAdaptor adaptor,
925 ConversionPatternRewriter &rewriter)
const override {
926 if (op.getKind() != vector::CombiningKind::ADD)
927 return rewriter.notifyMatchFailure(op,
"combining kind is not 'add'");
929 auto resultType = getTypeConverter()->convertType<FloatType>(op.getType());
931 return rewriter.notifyMatchFailure(op,
"result is not a float");
933 Value vec = adaptor.getVector();
934 Value acc = adaptor.getAcc();
936 auto vectorType = dyn_cast<VectorType>(vec.
getType());
938 assert(isa<FloatType>(vec.
getType()) &&
939 "Expected the vector to be scalarized");
941 rewriter.replaceOpWithNewOp<spirv::FAddOp>(op, acc, vec);
945 rewriter.replaceOp(op, vec);
949 Location loc = op.getLoc();
960 rewriter.getFloatAttr(vectorType.getElementType(), 1.0);
961 oneAttr = SplatElementsAttr::get(vectorType, oneAttr);
962 rhs = spirv::ConstantOp::create(rewriter, loc, vectorType, oneAttr);
967 Value res = spirv::DotOp::create(rewriter, loc, resultType,
lhs,
rhs);
969 res = spirv::FAddOp::create(rewriter, loc, acc, res);
971 rewriter.replaceOp(op, res);
976struct VectorStepOpConvert final : OpConversionPattern<vector::StepOp> {
980 matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
981 ConversionPatternRewriter &rewriter)
const override {
982 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
983 Type dstType = typeConverter.convertType(stepOp.getType());
987 Location loc = stepOp.getLoc();
988 int64_t numElements = stepOp.getType().getNumElements();
990 rewriter.getIntegerType(typeConverter.getIndexTypeBitwidth());
994 if (numElements == 1) {
995 Value zero = spirv::ConstantOp::getZero(intType, loc, rewriter);
996 rewriter.replaceOp(stepOp, zero);
1000 SmallVector<Value> source;
1001 source.reserve(numElements);
1002 for (int64_t i = 0; i < numElements; ++i) {
1003 Attribute intAttr = rewriter.getIntegerAttr(intType, i);
1005 spirv::ConstantOp::create(rewriter, loc, intType, intAttr);
1006 source.push_back(constOp);
1008 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(stepOp, dstType,
1014struct VectorToElementOpConvert final
1015 : OpConversionPattern<vector::ToElementsOp> {
1019 matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
1020 ConversionPatternRewriter &rewriter)
const override {
1022 SmallVector<Value> results(toElementsOp->getNumResults());
1023 Location loc = toElementsOp.getLoc();
1028 if (isa<spirv::ScalarType>(adaptor.getSource().getType())) {
1029 results[0] = adaptor.getSource();
1030 rewriter.replaceOp(toElementsOp, results);
1034 Type srcElementType = toElementsOp.getElements().getType().front();
1035 Type elementType = getTypeConverter()->convertType(srcElementType);
1037 return rewriter.notifyMatchFailure(
1039 llvm::formatv(
"failed to convert element type '{0}' to SPIR-V",
1042 for (
auto [idx, element] : llvm::enumerate(toElementsOp.getElements())) {
1045 if (element.use_empty())
1048 Value
result = spirv::CompositeExtractOp::create(
1049 rewriter, loc, elementType, adaptor.getSource(),
1050 rewriter.getI32ArrayAttr({static_cast<int32_t>(idx)}));
1054 rewriter.replaceOp(toElementsOp, results);
1060#define CL_INT_MAX_MIN_OPS \
1061 spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
1063#define GL_INT_MAX_MIN_OPS \
1064 spirv::GLUMaxOp, spirv::GLUMinOp, spirv::GLSMaxOp, spirv::GLSMinOp
1066#define CL_FLOAT_MAX_MIN_OPS spirv::CLFMaxOp, spirv::CLFMinOp
1067#define GL_FLOAT_MAX_MIN_OPS spirv::GLFMaxOp, spirv::GLFMinOp
1072 VectorBitcastConvert, VectorBroadcastConvert, VectorExtractOpConvert,
1073 VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
1074 VectorFmaOpConvert<spirv::CLFmaOp>, VectorFromElementsOpConvert,
1075 VectorToElementOpConvert, VectorInsertOpConvert,
1076 VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
1077 VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
1078 VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
1079 VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
1080 VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
1081 VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
1082 VectorScalarBroadcastPattern, VectorLoadOpConverter,
1083 VectorStoreOpConverter, VectorStepOpConvert>(
1088 patterns.add<VectorReductionToFPDotProd>(typeConverter,
patterns.getContext(),
static constexpr unsigned getNumBits()
static uint64_t getFirstIntValue(ArrayAttr attr)
Returns the integer value from the first valid input element, assuming Value inputs are defined by a ...
#define INT_AND_FLOAT_CASE(kind, iop, fop)
#define INT_OR_FLOAT_CASE(kind, fop)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static PointerType get(Type pointeeType, StorageClass storageClass)
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void populateVectorReductionToSPIRVDotProductPatterns(RewritePatternSet &patterns)
Appends patterns to convert vector reduction of the form:
const FrozenRewritePatternSet & patterns
void populateVectorToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating Vector Ops to SPIR-V ops.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...