28#include "llvm/ADT/ArrayRef.h"
29#include "llvm/ADT/STLExtras.h"
30#include "llvm/ADT/SmallVector.h"
31#include "llvm/ADT/SmallVectorExtras.h"
32#include "llvm/Support/FormatVariadic.h"
43 return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
51 if (
auto vectorType = dyn_cast<VectorType>(type))
52 return vectorType.getNumElements() * vectorType.getElementTypeBitWidth();
// Conversion pattern for vector::ShapeCastOp.
// Visible logic: the op is replaced by the adaptor source value when the
// converted result type already equals the adaptor source type, or when the
// result vector holds exactly one element.
// NOTE(review): this listing is missing lines (for example the return
// statements and closing braces); comments cover only what is visible.
58struct VectorShapeCast final :
public OpConversionPattern<vector::ShapeCastOp> {
62 matchAndRewrite(vector::ShapeCastOp shapeCastOp, OpAdaptor adaptor,
63 ConversionPatternRewriter &rewriter)
const override {
// Destination type as produced by the type converter for the op result.
64 Type dstType = getTypeConverter()->convertType(shapeCastOp.getType());
// The cast is a no-op: forward the already converted source.
70 if (dstType == adaptor.getSource().getType() ||
71 shapeCastOp.getResultVectorType().getNumElements() == 1) {
72 rewriter.replaceOp(shapeCastOp, adaptor.getSource());
// Conversion pattern for vector::BitCastOp.
// Visible logic: if the converted result type equals the adaptor source type
// the op is replaced by the source; a match failure is reported with a
// message about differing source and target bitwidth (the guarding condition
// is not visible in this listing); otherwise a spirv::BitcastOp to the
// converted type is created.
81struct VectorBitcastConvert final
82 :
public OpConversionPattern<vector::BitCastOp> {
86 matchAndRewrite(vector::BitCastOp bitcastOp, OpAdaptor adaptor,
87 ConversionPatternRewriter &rewriter)
const override {
88 Type dstType = getTypeConverter()->convertType(bitcastOp.getType());
// Same type before and after: nothing to cast.
92 if (dstType == adaptor.getSource().getType()) {
93 rewriter.replaceOp(bitcastOp, adaptor.getSource());
100 Type srcType = adaptor.getSource().getType();
// NOTE(review): presumably guarded by a bitwidth comparison of srcType and
// dstType; the condition line is missing here - confirm against the file.
102 return rewriter.notifyMatchFailure(
104 llvm::formatv(
"different source ({0}) and target ({1}) bitwidth",
108 rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, dstType,
109 adaptor.getSource());
// Conversion pattern for vector::BroadcastOp.
// Visible logic: when the converted result type is a SPIR-V scalar the op is
// replaced by the adaptor source; otherwise the source value is repeated once
// per result element and fed to a spirv::CompositeConstructOp.
114struct VectorBroadcastConvert final
115 :
public OpConversionPattern<vector::BroadcastOp> {
119 matchAndRewrite(vector::BroadcastOp castOp, OpAdaptor adaptor,
120 ConversionPatternRewriter &rewriter)
const override {
122 getTypeConverter()->convertType(castOp.getResultVectorType());
// Result converted to a scalar: the broadcast is the source itself.
126 if (isa<spirv::ScalarType>(resultType)) {
127 rewriter.replaceOp(castOp, adaptor.getSource());
// One copy of the source per element of the result vector.
131 SmallVector<Value, 4> source(castOp.getResultVectorType().getNumElements(),
132 adaptor.getSource());
133 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(castOp, resultType,
// Produces a sanitized version of a dynamic vector index.
// Visible logic:
//  - If vectorSize is a power of two, a constant with value vectorSize - 1 is
//    created in the index type and combined with the index through
//    spirv::BitwiseAndOp.
//  - Otherwise the index is compared for equality with a constant holding
//    kPoisonIndex and a spirv::SelectOp picks a zero constant when they match.
// NOTE(review): the loc and dynamicIndex parameters and the trailing operands
// of the And/Select calls are on lines missing from this listing.
145static Value sanitizeDynamicIndex(ConversionPatternRewriter &rewriter,
147 int64_t kPoisonIndex,
unsigned vectorSize) {
148 if (llvm::isPowerOf2_32(vectorSize)) {
// Mask constant: all low bits set for a power-of-two size.
149 Value inBoundsMask = spirv::ConstantOp::create(
150 rewriter, loc, dynamicIndex.
getType(),
151 rewriter.getIntegerAttr(dynamicIndex.
getType(), vectorSize - 1));
152 return spirv::BitwiseAndOp::create(rewriter, loc, dynamicIndex,
// Non power-of-two size: only the poison sentinel is special-cased.
155 Value poisonIndex = spirv::ConstantOp::create(
156 rewriter, loc, dynamicIndex.
getType(),
157 rewriter.getIntegerAttr(dynamicIndex.
getType(), kPoisonIndex));
159 spirv::IEqualOp::create(rewriter, loc, dynamicIndex, poisonIndex);
160 return spirv::SelectOp::create(
161 rewriter, loc, cmpResult,
162 spirv::ConstantOp::getZero(dynamicIndex.
getType(), loc, rewriter),
// Conversion pattern for vector::ExtractOp.
// Visible logic:
//  - If the adaptor source is already a SPIR-V scalar, the op is replaced by
//    that source.
//  - With a constant index: a poison index is reported as a match failure,
//    any other index becomes a spirv::CompositeExtractOp.
//  - With a dynamic index: the index goes through sanitizeDynamicIndex and a
//    spirv::VectorExtractDynamicOp is created.
// NOTE(review): the expression initializing id and the branch structure
// between the static and dynamic paths are on lines missing here.
166struct VectorExtractOpConvert final
167 :
public OpConversionPattern<vector::ExtractOp> {
171 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
172 ConversionPatternRewriter &rewriter)
const override {
173 Type dstType = getTypeConverter()->convertType(extractOp.getType());
// Source was converted to a scalar: extraction is the identity.
177 if (isa<spirv::ScalarType>(adaptor.getSource().getType())) {
178 rewriter.replaceOp(extractOp, adaptor.getSource());
182 if (std::optional<int64_t>
id =
// Static poison index is rejected by this pattern.
184 if (
id == vector::ExtractOp::kPoisonIndex)
185 return rewriter.notifyMatchFailure(
187 "Static use of poison index handled elsewhere (folded to poison)");
188 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
189 extractOp, dstType, adaptor.getSource(),
190 rewriter.getI32ArrayAttr(
id.value()));
// Dynamic index: sanitize against the source vector element count first.
192 Value sanitizedIndex = sanitizeDynamicIndex(
193 rewriter, extractOp.getLoc(), adaptor.getDynamicPosition()[0],
194 vector::ExtractOp::kPoisonIndex,
195 extractOp.getSourceVectorType().getNumElements());
196 rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
197 extractOp, dstType, adaptor.getSource(), sanitizedIndex);
// Conversion pattern for vector::ExtractStridedSliceOp.
// Visible logic: the source vector is the first adaptor operand. When the
// converted result type is a SPIR-V scalar a spirv::CompositeExtractOp is
// created; otherwise an index list of length size is built and the slice is
// produced by a spirv::VectorShuffleOp that uses the source vector for both
// shuffle inputs.
// NOTE(review): how size and the index values are computed is on lines
// missing from this listing.
203struct VectorExtractStridedSliceOpConvert final
204 :
public OpConversionPattern<vector::ExtractStridedSliceOp> {
208 matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
209 ConversionPatternRewriter &rewriter)
const override {
210 Type dstType = getTypeConverter()->convertType(extractOp.getType());
220 Value srcVector = adaptor.getOperands().front();
// Single-element result: extract instead of shuffle.
223 if (isa<spirv::ScalarType>(dstType)) {
224 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(extractOp,
229 SmallVector<int32_t, 2>
indices(size);
232 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
233 extractOp, dstType, srcVector, srcVector,
234 rewriter.getI32ArrayAttr(
indices));
// Conversion pattern for vector::FMAOp, parameterized on the SPIR-V FMA op
// class to emit (SPIRVFMAOp).
// Visible logic: the op is replaced by a SPIRVFMAOp of the converted result
// type, built from the adaptor lhs, rhs and acc operands.
240template <
class SPIRVFMAOp>
241struct VectorFmaOpConvert final :
public OpConversionPattern<vector::FMAOp> {
245 matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
246 ConversionPatternRewriter &rewriter)
const override {
247 Type dstType = getTypeConverter()->convertType(fmaOp.getType());
250 rewriter.replaceOpWithNewOp<SPIRVFMAOp>(fmaOp, dstType, adaptor.getLhs(),
251 adaptor.getRhs(), adaptor.getAcc());
// Conversion pattern for vector::FromElementsOp.
// Visible logic: when the converted result type is a SPIR-V scalar the op is
// replaced by the first element; otherwise the converted result is asserted
// to be a rank-1 vector and a spirv::CompositeConstructOp is created.
// NOTE(review): the definition of elements is on a line missing here.
256struct VectorFromElementsOpConvert final
257 :
public OpConversionPattern<vector::FromElementsOp> {
261 matchAndRewrite(vector::FromElementsOp op, OpAdaptor adaptor,
262 ConversionPatternRewriter &rewriter)
const override {
263 Type resultType = getTypeConverter()->convertType(op.getType());
// Scalar result: forward the single element.
267 if (isa<spirv::ScalarType>(resultType)) {
270 rewriter.replaceOp(op, elements[0]);
275 assert(cast<VectorType>(resultType).getRank() == 1);
276 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, resultType,
// Conversion pattern for vector::InsertOp when the stored value is a scalar.
// Visible logic:
//  - Vector-typed stored values and destination vector types that fail type
//    conversion are reported as match failures.
//  - Storing an int or float into a one-element destination replaces the op
//    with the adaptor value to store.
//  - With a constant index: a poison index is reported as a match failure,
//    any other index becomes a spirv::CompositeInsertOp.
//  - With a dynamic index: the index goes through sanitizeDynamicIndex and a
//    spirv::VectorInsertDynamicOp is created.
// Both insert paths take the destination from the adaptor so that the new
// SPIR-V op is built from the converted operand; the dynamic path previously
// used the original op operand, which disagreed with the static path.
// NOTE(review): the expression initializing id and the branch structure
// between the static and dynamic paths are on lines missing here.
282struct VectorInsertOpConvert final
283 :
public OpConversionPattern<vector::InsertOp> {
287 matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
288 ConversionPatternRewriter &rewriter)
const override {
289 if (isa<VectorType>(insertOp.getValueToStoreType()))
290 return rewriter.notifyMatchFailure(insertOp,
"unsupported vector source");
291 if (!getTypeConverter()->convertType(insertOp.getDestVectorType()))
292 return rewriter.notifyMatchFailure(insertOp,
293 "unsupported dest vector type");
// Scalar stored into a one-element vector: the result is the scalar itself.
296 if (insertOp.getValueToStoreType().isIntOrFloat() &&
297 insertOp.getDestVectorType().getNumElements() == 1) {
298 rewriter.replaceOp(insertOp, adaptor.getValueToStore());
302 if (std::optional<int64_t>
id =
// Static poison index is rejected by this pattern.
304 if (
id == vector::InsertOp::kPoisonIndex)
305 return rewriter.notifyMatchFailure(
307 "Static use of poison index handled elsewhere (folded to poison)");
308 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
309 insertOp, adaptor.getValueToStore(), adaptor.getDest(),
id.value());
// Dynamic index: sanitize against the destination element count first.
311 Value sanitizedIndex = sanitizeDynamicIndex(
312 rewriter, insertOp.getLoc(), adaptor.getDynamicPosition()[0],
313 vector::InsertOp::kPoisonIndex,
314 insertOp.getDestVectorType().getNumElements());
315 rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
316 insertOp, adaptor.getDest(), adaptor.getValueToStore(),
// Conversion pattern for vector::InsertStridedSliceOp.
// Visible logic: the source is the first adaptor operand and the destination
// is the last one. A SPIR-V scalar source is inserted with a
// spirv::CompositeInsertOp at offset. Otherwise an index list with one entry
// per destination element is built, the range [offset, offset + insertSize)
// is filled with std::iota, and a spirv::VectorShuffleOp over the destination
// and source vectors produces the result.
// NOTE(review): the definition of offset, the initial fill of indices and
// the iota start value are on lines missing from this listing.
323struct VectorInsertStridedSliceOpConvert final
324 :
public OpConversionPattern<vector::InsertStridedSliceOp> {
328 matchAndRewrite(vector::InsertStridedSliceOp insertOp, OpAdaptor adaptor,
329 ConversionPatternRewriter &rewriter)
const override {
330 Value srcVector = adaptor.getOperands().front();
331 Value dstVector = adaptor.getOperands().back();
// Scalar source: a single composite insert is enough.
338 if (isa<spirv::ScalarType>(srcVector.
getType())) {
339 assert(!isa<spirv::ScalarType>(dstVector.
getType()));
340 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
341 insertOp, dstVector.
getType(), srcVector, dstVector,
342 rewriter.getI32ArrayAttr(offset));
// Element counts of the destination and of the inserted slice.
346 uint64_t totalSize = cast<VectorType>(dstVector.
getType()).getNumElements();
347 uint64_t insertSize =
348 cast<VectorType>(srcVector.
getType()).getNumElements();
350 SmallVector<int32_t, 2>
indices(totalSize);
// Overwrite the shuffle indices covering the inserted range.
352 std::iota(
indices.begin() + offset,
indices.begin() + offset + insertSize,
355 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
356 insertOp, dstVector.
getType(), dstVector, srcVector,
357 rewriter.getI32ArrayAttr(
indices));
364 vector::ReductionOp reduceOp, vector::ReductionOp::Adaptor adaptor,
365 VectorType srcVectorType, ConversionPatternRewriter &rewriter) {
366 int numElements =
static_cast<int>(srcVectorType.getDimSize(0));
368 values.reserve(numElements + (adaptor.getAcc() ? 1 : 0));
371 for (
int i = 0; i < numElements; ++i) {
372 values.push_back(spirv::CompositeExtractOp::create(
373 rewriter, loc, srcVectorType.getElementType(), adaptor.getVector(),
374 rewriter.getI32ArrayAttr({i})));
377 values.push_back(
acc);
// Aggregate returned by getReductionInfo. It is brace-initialized there with
// the converted result type followed by the extracted element values.
382struct ReductionRewriteInfo {
// Values produced by extractAllElements for the reduction.
384 SmallVector<Value> extractedElements;
// Gathers what the reduction patterns need before rewriting.
// Visible logic: converts the op result type, requires the adaptor vector
// operand to be a rank-1 VectorType (otherwise a match failure is reported),
// calls extractAllElements, and returns a ReductionRewriteInfo holding the
// converted result type and the extracted elements.
// NOTE(review): any check on resultType and the declaration receiving the
// extractAllElements result are on lines missing from this listing.
387FailureOr<ReductionRewriteInfo>
static getReductionInfo(
388 vector::ReductionOp op, vector::ReductionOp::Adaptor adaptor,
389 ConversionPatternRewriter &rewriter,
const TypeConverter &typeConverter) {
390 Type resultType = typeConverter.convertType(op.getType());
// Only rank-1 vector sources are handled.
394 auto srcVectorType = dyn_cast<VectorType>(adaptor.getVector().getType());
395 if (!srcVectorType || srcVectorType.getRank() != 1)
396 return rewriter.notifyMatchFailure(op,
"not a 1-D vector source");
399 extractAllElements(op, adaptor, srcVectorType, rewriter);
401 return ReductionRewriteInfo{resultType, std::move(extractedElements)};
// Conversion pattern for vector::ReductionOp, parameterized on the SPIR-V
// unsigned/signed integer max and min op classes.
// Visible logic:
//  - getReductionInfo supplies the converted result type and the extracted
//    elements.
//  - Kind OR is lowered to spirv::AnyOp on the whole vector, combined with
//    the accumulator (when present) through spirv::LogicalOrOp.
//  - Kind AND is lowered to spirv::AllOp, combined with the accumulator
//    (when present) through spirv::LogicalAndOp.
//  - Other kinds fold the extracted elements left to right inside a switch
//    whose cases are generated by the helper macros below.
// NOTE(review): the macro invocations, any type guards around the OR/AND
// paths and the returns are on lines missing from this listing.
404template <
typename SPIRVUMaxOp,
typename SPIRVUMinOp,
typename SPIRVSMaxOp,
405 typename SPIRVSMinOp>
406struct VectorReductionPattern final : OpConversionPattern<vector::ReductionOp> {
410 matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
411 ConversionPatternRewriter &rewriter)
const override {
413 getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
414 if (
failed(reductionInfo))
417 auto [resultType, extractedElements] = *reductionInfo;
418 Location loc = reduceOp->getLoc();
422 vector::CombiningKind kind = reduceOp.getKind();
// OR reduction: any-of over the vector, then fold in the accumulator.
424 if (kind == vector::CombiningKind::OR) {
425 Value
result = spirv::AnyOp::create(rewriter, loc, resultType,
426 adaptor.getVector());
427 if (Value acc = adaptor.getAcc())
428 result = spirv::LogicalOrOp::create(rewriter, loc, resultType,
result,
430 rewriter.replaceOp(reduceOp,
result);
// AND reduction: all-of over the vector, then fold in the accumulator.
434 if (kind == vector::CombiningKind::AND) {
435 Value
result = spirv::AllOp::create(rewriter, loc, resultType,
436 adaptor.getVector());
437 if (Value acc = adaptor.getAcc())
438 result = spirv::LogicalAndOp::create(rewriter, loc, resultType,
440 rewriter.replaceOp(reduceOp,
result);
// Generic path: start from the first element and combine the rest in order.
445 Value
result = extractedElements.front();
446 for (Value next : llvm::drop_begin(extractedElements)) {
447 switch (reduceOp.getKind()) {
// Case helper: picks the integer op for integer results, else the float op.
449#define INT_AND_FLOAT_CASE(kind, iop, fop) \
450 case vector::CombiningKind::kind: \
451 if (isa<IntegerType>(resultType)) { \
452 result = spirv::iop::create(rewriter, loc, resultType, result, next); \
454 assert(isa<FloatType>(resultType)); \
455 result = spirv::fop::create(rewriter, loc, resultType, result, next); \
// Case helper: a single op class, used as given without the spirv prefix.
459#define INT_OR_FLOAT_CASE(kind, fop) \
460 case vector::CombiningKind::kind: \
461 result = fop::create(rewriter, loc, resultType, result, next); \
// Case helper: integer-only kinds; asserts an integer result type.
464#define INT_CASE(kind, iop) \
465 case vector::CombiningKind::kind: \
466 assert(isa<IntegerType>(resultType)); \
467 result = spirv::iop::create(rewriter, loc, resultType, result, next); \
// NOTE(review): presumably the default case of the switch - confirm.
481 return rewriter.notifyMatchFailure(reduceOp,
"not handled here");
483#undef INT_AND_FLOAT_CASE
484#undef INT_OR_FLOAT_CASE
488 rewriter.replaceOp(reduceOp,
result);
// Conversion pattern for vector::ReductionOp, parameterized on the SPIR-V
// float max and min op classes.
// Visible logic: getReductionInfo supplies the converted result type and the
// extracted elements; the elements are folded left to right inside a switch
// whose cases come from INT_OR_FLOAT_CASE, and the op is replaced by the
// folded value.
// NOTE(review): the macro invocations listing the handled kinds are on lines
// missing from this listing.
493template <
typename SPIRVFMaxOp,
typename SPIRVFMinOp>
494struct VectorReductionFloatMinMax final
495 : OpConversionPattern<vector::ReductionOp> {
499 matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
500 ConversionPatternRewriter &rewriter)
const override {
502 getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
503 if (
failed(reductionInfo))
506 auto [resultType, extractedElements] = *reductionInfo;
507 Location loc = reduceOp->getLoc();
// Start from the first element and combine the rest in order.
508 Value
result = extractedElements.front();
509 for (Value next : llvm::drop_begin(extractedElements)) {
510 switch (reduceOp.getKind()) {
// Case helper: applies the given op class to the running result and next.
512#define INT_OR_FLOAT_CASE(kind, fop) \
513 case vector::CombiningKind::kind: \
514 result = fop::create(rewriter, loc, resultType, result, next); \
// NOTE(review): presumably the default case of the switch - confirm.
523 return rewriter.notifyMatchFailure(reduceOp,
"not handled here");
525#undef INT_OR_FLOAT_CASE
528 rewriter.replaceOp(reduceOp,
result);
// Conversion pattern for vector::BroadcastOp restricted to scalar sources.
// Visible logic: a vector-typed source is reported as a match failure. When
// the converted result type is a SPIR-V scalar the op is replaced by the
// adaptor source; otherwise the source is repeated once per element of the
// converted vector type and fed to a spirv::CompositeConstructOp.
533class VectorScalarBroadcastPattern final
534 :
public OpConversionPattern<vector::BroadcastOp> {
539 matchAndRewrite(vector::BroadcastOp op, OpAdaptor adaptor,
540 ConversionPatternRewriter &rewriter)
const override {
541 if (isa<VectorType>(op.getSourceType())) {
542 return rewriter.notifyMatchFailure(
543 op,
"only conversion of 'broadcast from scalar' is supported");
545 Type dstType = getTypeConverter()->convertType(op.getType());
// Result converted to a scalar: the broadcast is the source itself.
548 if (isa<spirv::ScalarType>(dstType)) {
549 rewriter.replaceOp(op, adaptor.getSource());
551 auto dstVecType = cast<VectorType>(dstType);
552 SmallVector<Value, 4> source(dstVecType.getNumElements(),
553 adaptor.getSource());
554 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstType,
// Conversion pattern for vector::ShuffleOp.
// Visible logic:
//  - The result vector type is converted; an unsupported type is reported as
//    a match failure (the guarding condition is not visible here).
//  - When both inputs and the result have more than one element, a
//    spirv::VectorShuffleOp with the mask as an i32 array is created.
//  - Otherwise each mask entry is resolved to one element: entries at or
//    beyond the V1 element count select from V2 after subtracting that
//    count. One resulting operand replaces the op directly; several are
//    combined with a spirv::CompositeConstructOp.
561struct VectorShuffleOpConvert final
562 :
public OpConversionPattern<vector::ShuffleOp> {
566 matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
567 ConversionPatternRewriter &rewriter)
const override {
568 VectorType oldResultType = shuffleOp.getResultVectorType();
569 Type newResultType = getTypeConverter()->convertType(oldResultType);
571 return rewriter.notifyMatchFailure(shuffleOp,
572 "unsupported result vector type");
574 auto mask = llvm::to_vector_of<int32_t>(shuffleOp.getMask());
576 VectorType oldV1Type = shuffleOp.getV1VectorType();
577 VectorType oldV2Type = shuffleOp.getV2VectorType();
// All operands are multi-element vectors: a direct shuffle is possible.
581 if (oldV1Type.getNumElements() > 1 && oldV2Type.getNumElements() > 1 &&
582 oldResultType.getNumElements() > 1) {
583 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
584 shuffleOp, newResultType, adaptor.getV1(), adaptor.getV2(),
585 rewriter.getI32ArrayAttr(mask));
// Helper: element idx of a vector value, or the value itself for a scalar
// (idx is asserted to be 0 in that case).
592 auto getElementAtIdx = [&rewriter, loc = shuffleOp.getLoc()](
593 Value scalarOrVec, int32_t idx) -> Value {
594 if (
auto vecTy = dyn_cast<VectorType>(scalarOrVec.getType()))
595 return spirv::CompositeExtractOp::create(rewriter, loc, scalarOrVec,
598 assert(idx == 0 &&
"Invalid scalar element index");
602 int32_t numV1Elems = oldV1Type.getNumElements();
603 SmallVector<Value> newOperands(mask.size());
// Resolve every mask entry to an element of V1 or V2.
604 for (
auto [shuffleIdx, newOperand] : llvm::zip_equal(mask, newOperands)) {
605 Value vec = adaptor.getV1();
606 int32_t elementIdx = shuffleIdx;
607 if (elementIdx >= numV1Elems) {
608 vec = adaptor.getV2();
609 elementIdx -= numV1Elems;
612 newOperand = getElementAtIdx(vec, elementIdx);
// A single selected element is the result itself.
616 if (newOperands.size() == 1) {
617 rewriter.replaceOp(shuffleOp, newOperands.front());
621 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
622 shuffleOp, newResultType, newOperands);
// Conversion pattern for vector::InterleaveOp.
// Visible logic: the result vector type is converted (an unsupported type is
// reported as a match failure; the guarding condition is not visible). One
// path builds the result with a spirv::CompositeConstructOp from the lhs and
// rhs operands. The other path builds 2 * n shuffle indices, where n is the
// source element count: even positions map to i / 2 and odd positions to
// n + i / 2, and a spirv::VectorShuffleOp over lhs and rhs is created.
// NOTE(review): the condition selecting the CompositeConstruct path is on a
// line missing from this listing.
627struct VectorInterleaveOpConvert final
628 :
public OpConversionPattern<vector::InterleaveOp> {
632 matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
633 ConversionPatternRewriter &rewriter)
const override {
635 VectorType oldResultType = interleaveOp.getResultVectorType();
636 Type newResultType = getTypeConverter()->convertType(oldResultType);
638 return rewriter.notifyMatchFailure(interleaveOp,
639 "unsupported result vector type");
642 VectorType sourceType = interleaveOp.getSourceVectorType();
643 int n = sourceType.getNumElements();
649 Value newOperands[] = {adaptor.getLhs(), adaptor.getRhs()};
650 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
651 interleaveOp, newResultType, newOperands);
// Shuffle indices alternate between the first and the second input.
655 auto seq = llvm::seq<int64_t>(2 * n);
656 auto indices = llvm::map_to_vector(
657 seq, [n](
int i) {
return (i % 2 ? n : 0) + i / 2; });
660 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
661 interleaveOp, newResultType, adaptor.getLhs(), adaptor.getRhs(),
662 rewriter.getI32ArrayAttr(
indices));
// Conversion pattern for vector::DeinterleaveOp, which has two results.
// Visible logic: the result vector type is converted (an unsupported type is
// reported as a match failure; the guarding condition is not visible). One
// path extracts elements 0 and 1 of the source with spirv::CompositeExtractOp
// and uses them as the two results. The other path builds n / 2 even indices
// and n / 2 odd indices, where n is the source element count, and creates two
// spirv::VectorShuffleOp ops over the source vector for the two results.
// NOTE(review): the condition selecting the extract path is on a line
// missing from this listing.
668struct VectorDeinterleaveOpConvert final
669 :
public OpConversionPattern<vector::DeinterleaveOp> {
673 matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
674 ConversionPatternRewriter &rewriter)
const override {
677 VectorType oldResultType = deinterleaveOp.getResultVectorType();
678 Type newResultType = getTypeConverter()->convertType(oldResultType);
680 return rewriter.notifyMatchFailure(deinterleaveOp,
681 "unsupported result vector type");
683 Location loc = deinterleaveOp->getLoc();
686 Value sourceVector = adaptor.getSource();
687 VectorType sourceType = deinterleaveOp.getSourceVectorType();
688 int n = sourceType.getNumElements();
// Element-wise path: results are the two individual source elements.
694 auto elem0 = spirv::CompositeExtractOp::create(
695 rewriter, loc, newResultType, sourceVector,
696 rewriter.getI32ArrayAttr({0}));
698 auto elem1 = spirv::CompositeExtractOp::create(
699 rewriter, loc, newResultType, sourceVector,
700 rewriter.getI32ArrayAttr({1}));
702 rewriter.replaceOp(deinterleaveOp, {elem0, elem1});
// Even positions 0, 2, 4, ... of the source.
707 auto seqEven = llvm::seq<int64_t>(n / 2);
709 llvm::map_to_vector(seqEven, [](
int i) {
return i * 2; });
// Odd positions 1, 3, 5, ... of the source.
712 auto seqOdd = llvm::seq<int64_t>(n / 2);
714 llvm::map_to_vector(seqOdd, [](
int i) {
return i * 2 + 1; });
717 auto shuffleEven = spirv::VectorShuffleOp::create(
718 rewriter, loc, newResultType, sourceVector, sourceVector,
719 rewriter.getI32ArrayAttr(indicesEven));
721 auto shuffleOdd = spirv::VectorShuffleOp::create(
722 rewriter, loc, newResultType, sourceVector, sourceVector,
723 rewriter.getI32ArrayAttr(indicesOdd));
725 rewriter.replaceOp(deinterleaveOp, {shuffleEven, shuffleOdd});
// Conversion pattern for vector::LoadOp.
// Visible logic:
//  - The memref memory space must be a spirv::StorageClassAttr; otherwise a
//    match failure is reported.
//  - An element pointer is computed from the adaptor indices; failure to get
//    it is reported as a match failure.
//  - The vector type must convert; otherwise a match failure is reported.
//  - An alignment larger than the uint32_t maximum is rejected. When an
//    alignment is present, the Aligned memory access flag is set and the
//    alignment is stored as an i32 attribute.
//  - The access chain is cast with spirv::BitcastOp unless the vector has a
//    single element, and the op is replaced by a spirv::LoadOp.
// NOTE(review): several declarations (attr, accessChain, vectorPtrType) are
// on lines missing from this listing.
730struct VectorLoadOpConverter final
731 :
public OpConversionPattern<vector::LoadOp> {
735 matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
736 ConversionPatternRewriter &rewriter)
const override {
737 auto memrefType = loadOp.getMemRefType();
739 dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
741 return rewriter.notifyMatchFailure(
742 loadOp,
"expected spirv.storage_class memory space");
744 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
745 auto loc = loadOp.getLoc();
748 adaptor.getIndices(), loc, rewriter);
750 return rewriter.notifyMatchFailure(
751 loadOp,
"failed to get memref element pointer");
753 spirv::StorageClass storageClass = attr.getValue();
754 auto vectorType = loadOp.getVectorType();
757 auto spirvVectorType = typeConverter.convertType(vectorType);
758 if (!spirvVectorType)
759 return rewriter.notifyMatchFailure(loadOp,
"unsupported vector type");
// The alignment is later stored in an i32 attribute, so it must fit 32 bits.
763 std::optional<uint64_t> alignment = loadOp.getAlignment();
764 if (alignment > std::numeric_limits<uint32_t>::max()) {
765 return rewriter.notifyMatchFailure(loadOp,
766 "invalid alignment requirement");
// Both attributes stay null unless an alignment was given.
769 auto memoryAccess = spirv::MemoryAccess::None;
770 spirv::MemoryAccessAttr memoryAccessAttr;
771 IntegerAttr alignmentAttr;
772 if (alignment.has_value()) {
773 memoryAccess |= spirv::MemoryAccess::Aligned;
775 spirv::MemoryAccessAttr::get(rewriter.getContext(), memoryAccess);
776 alignmentAttr = rewriter.getI32IntegerAttr(alignment.value());
// Single-element vectors skip the pointer bitcast.
782 Value castedAccessChain =
783 (vectorType.getNumElements() == 1)
785 : spirv::BitcastOp::create(rewriter, loc, vectorPtrType,
788 rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, spirvVectorType,
790 memoryAccessAttr, alignmentAttr);
// Conversion pattern for vector::StoreOp.
// Visible logic:
//  - The memref memory space must be a spirv::StorageClassAttr; otherwise a
//    match failure is reported.
//  - An element pointer is computed from the adaptor indices; failure to get
//    it is reported as a match failure.
//  - An alignment larger than the uint32_t maximum is rejected. When an
//    alignment is present, the Aligned memory access flag is set and the
//    alignment is stored as an i32 attribute.
//  - The access chain is cast with spirv::BitcastOp unless the vector has a
//    single element, and the op is replaced by a spirv::StoreOp of the
//    adaptor value to store.
// NOTE(review): several declarations (attr, accessChain, vectorPtrType) are
// on lines missing from this listing.
796struct VectorStoreOpConverter final
797 :
public OpConversionPattern<vector::StoreOp> {
801 matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
802 ConversionPatternRewriter &rewriter)
const override {
803 auto memrefType = storeOp.getMemRefType();
805 dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
807 return rewriter.notifyMatchFailure(
808 storeOp,
"expected spirv.storage_class memory space");
810 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
811 auto loc = storeOp.getLoc();
814 adaptor.getIndices(), loc, rewriter);
816 return rewriter.notifyMatchFailure(
817 storeOp,
"failed to get memref element pointer");
// The alignment is later stored in an i32 attribute, so it must fit 32 bits.
819 std::optional<uint64_t> alignment = storeOp.getAlignment();
820 if (alignment > std::numeric_limits<uint32_t>::max()) {
821 return rewriter.notifyMatchFailure(storeOp,
822 "invalid alignment requirement");
825 spirv::StorageClass storageClass = attr.getValue();
826 auto vectorType = storeOp.getVectorType();
// Single-element vectors skip the pointer bitcast.
832 Value castedAccessChain =
833 (vectorType.getNumElements() == 1)
835 : spirv::BitcastOp::create(rewriter, loc, vectorPtrType,
// Both attributes stay null unless an alignment was given.
838 auto memoryAccess = spirv::MemoryAccess::None;
839 spirv::MemoryAccessAttr memoryAccessAttr;
840 IntegerAttr alignmentAttr;
841 if (alignment.has_value()) {
842 memoryAccess |= spirv::MemoryAccess::Aligned;
844 spirv::MemoryAccessAttr::get(rewriter.getContext(), memoryAccess);
845 alignmentAttr = rewriter.getI32IntegerAttr(alignment.value());
848 rewriter.replaceOpWithNewOp<spirv::StoreOp>(
849 storeOp, castedAccessChain, adaptor.getValueToStore(), memoryAccessAttr,
// Rewrite pattern that maps an integer add-reduction of a product onto the
// SPIR-V integer dot product ops.
// Visible preconditions in matchAndRewrite:
//  - the combining kind is ADD;
//  - the result type is an IntegerType with a bitwidth of 32 or 64;
//  - the source vector is rank 1, not scalable, with 4 or 3 elements;
//  - the reduced vector is defined by an arith::MulIOp.
// Four handleCase instantiations are then tried in order, covering the
// combinations of sign and zero extension on the two multiplicands; the last
// one (zero-extended lhs, sign-extended rhs) requests swapped operands.
// NOTE(review): the base class, the failure returns and the body of the
// accumulator branch are on lines missing from this listing.
856struct VectorReductionToIntDotProd final
860 LogicalResult matchAndRewrite(vector::ReductionOp op,
861 PatternRewriter &rewriter)
const override {
862 if (op.getKind() != vector::CombiningKind::ADD)
865 auto resultType = dyn_cast<IntegerType>(op.getType());
870 if (!llvm::is_contained({32, 64}, resultBitwidth))
873 VectorType inVecTy = op.getSourceVectorType();
874 if (!llvm::is_contained({4, 3}, inVecTy.getNumElements()) ||
875 inVecTy.getShape().size() != 1 || inVecTy.isScalable())
878 auto mul = op.getVector().getDefiningOp<arith::MulIOp>();
881 op,
"reduction operand is not 'arith.muli'");
// signed x signed
883 if (succeeded(handleCase<arith::ExtSIOp, arith::ExtSIOp, spirv::SDotOp,
884 spirv::SDotAccSatOp,
false>(op,
mul, rewriter)))
// unsigned x unsigned
887 if (succeeded(handleCase<arith::ExtUIOp, arith::ExtUIOp, spirv::UDotOp,
888 spirv::UDotAccSatOp,
false>(op,
mul, rewriter)))
// signed x unsigned
891 if (succeeded(handleCase<arith::ExtSIOp, arith::ExtUIOp, spirv::SUDotOp,
892 spirv::SUDotAccSatOp,
false>(op,
mul, rewriter)))
// unsigned x signed, handled by the same ops with operands swapped
895 if (succeeded(handleCase<arith::ExtUIOp, arith::ExtSIOp, spirv::SUDotOp,
896 spirv::SUDotAccSatOp,
true>(op,
mul, rewriter)))
// Tries one extension combination. Visible logic: both multiplicands must be
// defined by the requested extension ops and their inputs must have 8-bit
// integer elements; for 3-element sources both inputs are rebuilt as
// 4 x i8 vectors through spirv::CompositeConstructOp, with a zero i8
// constant available for padding.
903 template <
typename LhsExtensionOp,
typename RhsExtensionOp,
typename DotOp,
904 typename DotAccOp,
bool SwapOperands>
905 static LogicalResult handleCase(vector::ReductionOp op, arith::MulIOp
mul,
906 PatternRewriter &rewriter) {
907 auto lhs =
mul.getLhs().getDefiningOp<LhsExtensionOp>();
910 Value lhsIn =
lhs.getIn();
911 auto lhsInType = cast<VectorType>(lhsIn.
getType());
912 if (!lhsInType.getElementType().isInteger(8))
915 auto rhs =
mul.getRhs().getDefiningOp<RhsExtensionOp>();
918 Value rhsIn =
rhs.getIn();
919 auto rhsInType = cast<VectorType>(rhsIn.
getType());
920 if (!rhsInType.getElementType().isInteger(8))
// Three-element inputs are widened to four elements.
923 if (op.getSourceVectorType().getNumElements() == 3) {
924 IntegerType i8Type = rewriter.
getI8Type();
925 auto v4i8Type = VectorType::get({4}, i8Type);
926 Location loc = op.getLoc();
927 Value zero = spirv::ConstantOp::getZero(i8Type, loc, rewriter);
928 lhsIn = spirv::CompositeConstructOp::create(rewriter, loc, v4i8Type,
930 rhsIn = spirv::CompositeConstructOp::create(rewriter, loc, v4i8Type,
// NOTE(review): presumably guarded by SwapOperands - confirm.
937 std::swap(lhsIn, rhsIn);
939 if (Value acc = op.getAcc()) {
// Conversion pattern that lowers a floating-point add-reduction to
// spirv::DotOp.
// Visible logic:
//  - A combining kind other than ADD and a non-float result are reported as
//    match failures.
//  - When the adaptor vector is not a VectorType, it is asserted to be a
//    float (already scalarized); the op becomes a spirv::FAddOp of acc and
//    the value, or is replaced by the value alone.
//  - Otherwise a splat constant of ones in the vector type can serve as the
//    right-hand side, a spirv::DotOp of lhs and rhs is created, and a
//    spirv::FAddOp folds in the accumulator.
// NOTE(review): the conditions on acc and the selection of lhs/rhs are on
// lines missing from this listing.
951struct VectorReductionToFPDotProd final
952 : OpConversionPattern<vector::ReductionOp> {
956 matchAndRewrite(vector::ReductionOp op, OpAdaptor adaptor,
957 ConversionPatternRewriter &rewriter)
const override {
958 if (op.getKind() != vector::CombiningKind::ADD)
959 return rewriter.notifyMatchFailure(op,
"combining kind is not 'add'");
961 auto resultType = getTypeConverter()->convertType<FloatType>(op.getType());
963 return rewriter.notifyMatchFailure(op,
"result is not a float");
965 Value vec = adaptor.getVector();
966 Value acc = adaptor.getAcc();
968 auto vectorType = dyn_cast<VectorType>(vec.
getType());
// Scalarized input: no dot product is needed.
970 assert(isa<FloatType>(vec.
getType()) &&
971 "Expected the vector to be scalarized");
973 rewriter.replaceOpWithNewOp<spirv::FAddOp>(op, acc, vec);
977 rewriter.replaceOp(op, vec);
981 Location loc = op.getLoc();
// Vector of ones: a dot product with it sums the elements of the other side.
992 rewriter.getFloatAttr(vectorType.getElementType(), 1.0);
993 oneAttr = SplatElementsAttr::get(vectorType, oneAttr);
994 rhs = spirv::ConstantOp::create(rewriter, loc, vectorType, oneAttr);
999 Value res = spirv::DotOp::create(rewriter, loc, resultType,
lhs,
rhs);
1001 res = spirv::FAddOp::create(rewriter, loc, acc, res);
1003 rewriter.replaceOp(op, res);
// Conversion pattern for vector::StepOp.
// Visible logic: the element integer type uses the index bitwidth of the
// SPIRVTypeConverter. A one-element step is replaced by a zero constant;
// otherwise one constant per element with values 0 .. numElements - 1 is
// created and combined with a spirv::CompositeConstructOp.
1008struct VectorStepOpConvert final : OpConversionPattern<vector::StepOp> {
1012 matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
1013 ConversionPatternRewriter &rewriter)
const override {
1014 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
1015 Type dstType = typeConverter.convertType(stepOp.getType());
1019 Location loc = stepOp.getLoc();
1020 int64_t numElements = stepOp.getType().getNumElements();
1022 rewriter.getIntegerType(typeConverter.getIndexTypeBitwidth());
// Single element: the only step value is zero.
1026 if (numElements == 1) {
1027 Value zero = spirv::ConstantOp::getZero(intType, loc, rewriter);
1028 rewriter.replaceOp(stepOp, zero);
// One constant per lane, in increasing order.
1032 SmallVector<Value> source;
1033 source.reserve(numElements);
1034 for (int64_t i = 0; i < numElements; ++i) {
1035 Attribute intAttr = rewriter.getIntegerAttr(intType, i);
1037 spirv::ConstantOp::create(rewriter, loc, intType, intAttr);
1038 source.push_back(constOp);
1040 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(stepOp, dstType,
// Conversion pattern for vector::ToElementsOp.
// Visible logic: a replacement list with one slot per op result is prepared.
// When the adaptor source is a SPIR-V scalar it becomes the only result.
// Otherwise the element type is converted (failure is reported with a
// formatted message) and a spirv::CompositeExtractOp is created per element
// index; the op is finally replaced by the collected values.
// NOTE(review): what happens for elements without uses and the store into
// results are on lines missing from this listing.
1046struct VectorToElementOpConvert final
1047 : OpConversionPattern<vector::ToElementsOp> {
1051 matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
1052 ConversionPatternRewriter &rewriter)
const override {
1054 SmallVector<Value> results(toElementsOp->getNumResults());
1055 Location loc = toElementsOp.getLoc();
// Scalar source: forward it as the single element.
1060 if (isa<spirv::ScalarType>(adaptor.getSource().getType())) {
1061 results[0] = adaptor.getSource();
1062 rewriter.replaceOp(toElementsOp, results);
1066 Type srcElementType = toElementsOp.getElements().getType().front();
1067 Type elementType = getTypeConverter()->convertType(srcElementType);
1069 return rewriter.notifyMatchFailure(
1071 llvm::formatv(
"failed to convert element type '{0}' to SPIR-V",
1074 for (
auto [idx, element] : llvm::enumerate(toElementsOp.getElements())) {
// NOTE(review): presumably unused elements are skipped - confirm.
1077 if (element.use_empty())
1080 Value
result = spirv::CompositeExtractOp::create(
1081 rewriter, loc, elementType, adaptor.getSource(),
1082 rewriter.getI32ArrayAttr({static_cast<int32_t>(idx)}));
1086 rewriter.replaceOp(toElementsOp, results);
// Op-class lists used as template argument packs for the reduction patterns.
// Integer lists are ordered unsigned max, unsigned min, signed max, signed
// min, matching the VectorReductionPattern template parameters.
1092#define CL_INT_MAX_MIN_OPS \
1093 spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
// Same ordering as above, using the GL op classes.
1095#define GL_INT_MAX_MIN_OPS \
1096 spirv::GLUMaxOp, spirv::GLUMinOp, spirv::GLSMaxOp, spirv::GLSMinOp
// Float lists are ordered max, min, matching VectorReductionFloatMinMax.
1098#define CL_FLOAT_MAX_MIN_OPS spirv::CLFMaxOp, spirv::CLFMinOp
1099#define GL_FLOAT_MAX_MIN_OPS spirv::GLFMaxOp, spirv::GLFMinOp
1104 VectorBitcastConvert, VectorBroadcastConvert, VectorExtractOpConvert,
1105 VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
1106 VectorFmaOpConvert<spirv::CLFmaOp>, VectorFromElementsOpConvert,
1107 VectorToElementOpConvert, VectorInsertOpConvert,
1108 VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
1109 VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
1110 VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
1111 VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
1112 VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
1113 VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
1114 VectorScalarBroadcastPattern, VectorLoadOpConverter,
1115 VectorStoreOpConverter, VectorStepOpConvert>(
1120 patterns.
add<VectorReductionToFPDotProd>(typeConverter, patterns.
getContext(),
1126 patterns.
add<VectorReductionToIntDotProd>(patterns.
getContext());
static constexpr unsigned getNumBits()
#define INT_CASE(kind, iop)
static uint64_t getFirstIntValue(ArrayAttr attr)
Returns the integer value from the first valid input element, assuming Value inputs are defined by a ...
#define INT_AND_FLOAT_CASE(kind, iop, fop)
#define INT_OR_FLOAT_CASE(kind, fop)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static PointerType get(Type pointeeType, StorageClass storageClass)
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void populateVectorReductionToSPIRVDotProductPatterns(RewritePatternSet &patterns)
Appends patterns to convert vector reduction of the form:
void populateVectorToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating Vector Ops to SPIR-V ops.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...