28#include "llvm/ADT/ArrayRef.h"
29#include "llvm/ADT/STLExtras.h"
30#include "llvm/ADT/SmallVector.h"
31#include "llvm/ADT/SmallVectorExtras.h"
32#include "llvm/Support/FormatVariadic.h"
43 return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
51 if (
auto vectorType = dyn_cast<VectorType>(type))
52 return vectorType.getNumElements() * vectorType.getElementTypeBitWidth();
58struct VectorShapeCast final :
public OpConversionPattern<vector::ShapeCastOp> {
62 matchAndRewrite(vector::ShapeCastOp shapeCastOp, OpAdaptor adaptor,
63 ConversionPatternRewriter &rewriter)
const override {
64 Type dstType = getTypeConverter()->convertType(shapeCastOp.getType());
70 if (dstType == adaptor.getSource().getType() ||
71 shapeCastOp.getResultVectorType().getNumElements() == 1) {
72 rewriter.replaceOp(shapeCastOp, adaptor.getSource());
81struct VectorBitcastConvert final
82 :
public OpConversionPattern<vector::BitCastOp> {
86 matchAndRewrite(vector::BitCastOp bitcastOp, OpAdaptor adaptor,
87 ConversionPatternRewriter &rewriter)
const override {
88 Type dstType = getTypeConverter()->convertType(bitcastOp.getType());
92 if (dstType == adaptor.getSource().getType()) {
93 rewriter.replaceOp(bitcastOp, adaptor.getSource());
100 Type srcType = adaptor.getSource().getType();
102 return rewriter.notifyMatchFailure(
104 llvm::formatv(
"different source ({0}) and target ({1}) bitwidth",
108 rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, dstType,
109 adaptor.getSource());
114struct VectorBroadcastConvert final
115 :
public OpConversionPattern<vector::BroadcastOp> {
119 matchAndRewrite(vector::BroadcastOp castOp, OpAdaptor adaptor,
120 ConversionPatternRewriter &rewriter)
const override {
122 getTypeConverter()->convertType(castOp.getResultVectorType());
126 if (isa<spirv::ScalarType>(resultType)) {
127 rewriter.replaceOp(castOp, adaptor.getSource());
131 SmallVector<Value, 4> source(castOp.getResultVectorType().getNumElements(),
132 adaptor.getSource());
133 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(castOp, resultType,
145static Value sanitizeDynamicIndex(ConversionPatternRewriter &rewriter,
147 int64_t kPoisonIndex,
unsigned vectorSize) {
148 if (llvm::isPowerOf2_32(vectorSize)) {
149 Value inBoundsMask = spirv::ConstantOp::create(
150 rewriter, loc, dynamicIndex.
getType(),
151 rewriter.getIntegerAttr(dynamicIndex.
getType(), vectorSize - 1));
152 return spirv::BitwiseAndOp::create(rewriter, loc, dynamicIndex,
155 Value poisonIndex = spirv::ConstantOp::create(
156 rewriter, loc, dynamicIndex.
getType(),
157 rewriter.getIntegerAttr(dynamicIndex.
getType(), kPoisonIndex));
159 spirv::IEqualOp::create(rewriter, loc, dynamicIndex, poisonIndex);
160 return spirv::SelectOp::create(
161 rewriter, loc, cmpResult,
162 spirv::ConstantOp::getZero(dynamicIndex.
getType(), loc, rewriter),
166struct VectorExtractOpConvert final
167 :
public OpConversionPattern<vector::ExtractOp> {
171 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
172 ConversionPatternRewriter &rewriter)
const override {
173 Type dstType = getTypeConverter()->convertType(extractOp.getType());
177 if (isa<spirv::ScalarType>(adaptor.getSource().getType())) {
178 rewriter.replaceOp(extractOp, adaptor.getSource());
182 if (std::optional<int64_t>
id =
184 if (
id == vector::ExtractOp::kPoisonIndex)
185 return rewriter.notifyMatchFailure(
187 "Static use of poison index handled elsewhere (folded to poison)");
188 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
189 extractOp, dstType, adaptor.getSource(),
190 rewriter.getI32ArrayAttr(
id.value()));
192 Value sanitizedIndex = sanitizeDynamicIndex(
193 rewriter, extractOp.getLoc(), adaptor.getDynamicPosition()[0],
194 vector::ExtractOp::kPoisonIndex,
195 extractOp.getSourceVectorType().getNumElements());
196 rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
197 extractOp, dstType, adaptor.getSource(), sanitizedIndex);
203struct VectorExtractStridedSliceOpConvert final
204 :
public OpConversionPattern<vector::ExtractStridedSliceOp> {
208 matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
209 ConversionPatternRewriter &rewriter)
const override {
210 Type dstType = getTypeConverter()->convertType(extractOp.getType());
220 Value srcVector = adaptor.getOperands().front();
223 if (isa<spirv::ScalarType>(dstType)) {
224 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(extractOp,
229 SmallVector<int32_t, 2>
indices(size);
232 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
233 extractOp, dstType, srcVector, srcVector,
234 rewriter.getI32ArrayAttr(
indices));
240template <
class SPIRVFMAOp>
241struct VectorFmaOpConvert final :
public OpConversionPattern<vector::FMAOp> {
245 matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
246 ConversionPatternRewriter &rewriter)
const override {
247 Type dstType = getTypeConverter()->convertType(fmaOp.getType());
250 rewriter.replaceOpWithNewOp<SPIRVFMAOp>(fmaOp, dstType, adaptor.getLhs(),
251 adaptor.getRhs(), adaptor.getAcc());
256struct VectorFromElementsOpConvert final
257 :
public OpConversionPattern<vector::FromElementsOp> {
261 matchAndRewrite(vector::FromElementsOp op, OpAdaptor adaptor,
262 ConversionPatternRewriter &rewriter)
const override {
263 Type resultType = getTypeConverter()->convertType(op.getType());
267 if (isa<spirv::ScalarType>(resultType)) {
270 rewriter.replaceOp(op, elements[0]);
275 assert(cast<VectorType>(resultType).getRank() == 1);
276 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, resultType,
// Lowers vector.insert of a non-vector value into SPIR-V.
// - Fails to match when the value to store is itself a vector, or when the
//   type converter cannot convert the destination vector type.
// - For a single-element destination with an int/float value, the op is
//   replaced by the stored value directly.
// - When a static position `id` is available (its initializer is not visible
//   in this listing), emits spirv.CompositeInsert; otherwise emits
//   spirv.VectorInsertDynamic with an index passed through
//   sanitizeDynamicIndex.
// NOTE(review): this listing is an incomplete extraction; several original
// lines of this pattern are not present here.
282struct VectorInsertOpConvert final
283 :
public OpConversionPattern<vector::InsertOp> {
287 matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
288 ConversionPatternRewriter &rewriter)
const override {
289 if (isa<VectorType>(insertOp.getValueToStoreType()))
290 return rewriter.notifyMatchFailure(insertOp,
"unsupported vector source");
291 if (!getTypeConverter()->convertType(insertOp.getDestVectorType()))
292 return rewriter.notifyMatchFailure(insertOp,
293 "unsupported dest vector type");
// Single-element destination: the result is just the stored scalar.
296 if (insertOp.getValueToStoreType().isIntOrFloat() &&
297 insertOp.getDestVectorType().getNumElements() == 1) {
298 rewriter.replaceOp(insertOp, adaptor.getValueToStore());
302 if (std::optional<int64_t>
id =
// A static poison position is not lowered here; per the failure message
// it is folded to poison by another pattern.
304 if (
id == vector::InsertOp::kPoisonIndex)
305 return rewriter.notifyMatchFailure(
307 "Static use of poison index handled elsewhere (folded to poison)");
308 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
309 insertOp, adaptor.getValueToStore(), adaptor.getDest(),
id.value());
// Dynamic position: sanitize the runtime index against the destination
// vector size (and the poison sentinel) before using it.
311 Value sanitizedIndex = sanitizeDynamicIndex(
312 rewriter, insertOp.getLoc(), adaptor.getDynamicPosition()[0],
313 vector::InsertOp::kPoisonIndex,
314 insertOp.getDestVectorType().getNumElements());
// NOTE(review): the static path above passes adaptor.getDest(), but this
// dynamic path passes insertOp.getDest(), the original pre-conversion
// operand. In a dialect conversion the adaptor's remapped operand is
// normally the one to use, so this likely should be adaptor.getDest().
// Confirm and fix.
315 rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
316 insertOp, insertOp.getDest(), adaptor.getValueToStore(),
323struct VectorInsertStridedSliceOpConvert final
324 :
public OpConversionPattern<vector::InsertStridedSliceOp> {
328 matchAndRewrite(vector::InsertStridedSliceOp insertOp, OpAdaptor adaptor,
329 ConversionPatternRewriter &rewriter)
const override {
330 Value srcVector = adaptor.getOperands().front();
331 Value dstVector = adaptor.getOperands().back();
338 if (isa<spirv::ScalarType>(srcVector.
getType())) {
339 assert(!isa<spirv::ScalarType>(dstVector.
getType()));
340 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
341 insertOp, dstVector.
getType(), srcVector, dstVector,
342 rewriter.getI32ArrayAttr(offset));
346 uint64_t totalSize = cast<VectorType>(dstVector.
getType()).getNumElements();
347 uint64_t insertSize =
348 cast<VectorType>(srcVector.
getType()).getNumElements();
350 SmallVector<int32_t, 2>
indices(totalSize);
352 std::iota(
indices.begin() + offset,
indices.begin() + offset + insertSize,
355 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
356 insertOp, dstVector.
getType(), dstVector, srcVector,
357 rewriter.getI32ArrayAttr(
indices));
364 vector::ReductionOp reduceOp, vector::ReductionOp::Adaptor adaptor,
365 VectorType srcVectorType, ConversionPatternRewriter &rewriter) {
366 int numElements =
static_cast<int>(srcVectorType.getDimSize(0));
368 values.reserve(numElements + (adaptor.getAcc() ? 1 : 0));
371 for (
int i = 0; i < numElements; ++i) {
372 values.push_back(spirv::CompositeExtractOp::create(
373 rewriter, loc, srcVectorType.getElementType(), adaptor.getVector(),
374 rewriter.getI32ArrayAttr({i})));
377 values.push_back(
acc);
382struct ReductionRewriteInfo {
384 SmallVector<Value> extractedElements;
387FailureOr<ReductionRewriteInfo>
static getReductionInfo(
388 vector::ReductionOp op, vector::ReductionOp::Adaptor adaptor,
389 ConversionPatternRewriter &rewriter,
const TypeConverter &typeConverter) {
390 Type resultType = typeConverter.convertType(op.getType());
394 auto srcVectorType = dyn_cast<VectorType>(adaptor.getVector().getType());
395 if (!srcVectorType || srcVectorType.getRank() != 1)
396 return rewriter.notifyMatchFailure(op,
"not a 1-D vector source");
399 extractAllElements(op, adaptor, srcVectorType, rewriter);
401 return ReductionRewriteInfo{resultType, std::move(extractedElements)};
404template <
typename SPIRVUMaxOp,
typename SPIRVUMinOp,
typename SPIRVSMaxOp,
405 typename SPIRVSMinOp>
406struct VectorReductionPattern final : OpConversionPattern<vector::ReductionOp> {
410 matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
411 ConversionPatternRewriter &rewriter)
const override {
413 getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
414 if (
failed(reductionInfo))
417 auto [resultType, extractedElements] = *reductionInfo;
418 Location loc = reduceOp->getLoc();
422 vector::CombiningKind kind = reduceOp.getKind();
424 if (kind == vector::CombiningKind::OR) {
425 Value
result = spirv::AnyOp::create(rewriter, loc, resultType,
426 adaptor.getVector());
427 if (Value acc = adaptor.getAcc())
428 result = spirv::LogicalOrOp::create(rewriter, loc, resultType,
result,
430 rewriter.replaceOp(reduceOp,
result);
434 if (kind == vector::CombiningKind::AND) {
435 Value
result = spirv::AllOp::create(rewriter, loc, resultType,
436 adaptor.getVector());
437 if (Value acc = adaptor.getAcc())
438 result = spirv::LogicalAndOp::create(rewriter, loc, resultType,
440 rewriter.replaceOp(reduceOp,
result);
445 Value
result = extractedElements.front();
446 for (Value next : llvm::drop_begin(extractedElements)) {
447 switch (reduceOp.getKind()) {
449#define INT_AND_FLOAT_CASE(kind, iop, fop) \
450 case vector::CombiningKind::kind: \
451 if (isa<IntegerType>(resultType)) { \
452 result = spirv::iop::create(rewriter, loc, resultType, result, next); \
454 assert(isa<FloatType>(resultType)); \
455 result = spirv::fop::create(rewriter, loc, resultType, result, next); \
459#define INT_OR_FLOAT_CASE(kind, fop) \
460 case vector::CombiningKind::kind: \
461 result = fop::create(rewriter, loc, resultType, result, next); \
464#define INT_CASE(kind, iop) \
465 case vector::CombiningKind::kind: \
466 assert(isa<IntegerType>(resultType)); \
467 result = spirv::iop::create(rewriter, loc, resultType, result, next); \
481 return rewriter.notifyMatchFailure(reduceOp,
"not handled here");
483#undef INT_AND_FLOAT_CASE
484#undef INT_OR_FLOAT_CASE
488 rewriter.replaceOp(reduceOp,
result);
493template <
typename SPIRVFMaxOp,
typename SPIRVFMinOp>
494struct VectorReductionFloatMinMax final
495 : OpConversionPattern<vector::ReductionOp> {
499 matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
500 ConversionPatternRewriter &rewriter)
const override {
502 getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
503 if (
failed(reductionInfo))
506 auto [resultType, extractedElements] = *reductionInfo;
507 Location loc = reduceOp->getLoc();
508 Value
result = extractedElements.front();
509 for (Value next : llvm::drop_begin(extractedElements)) {
510 switch (reduceOp.getKind()) {
512#define INT_OR_FLOAT_CASE(kind, fop) \
513 case vector::CombiningKind::kind: \
514 result = fop::create(rewriter, loc, resultType, result, next); \
523 return rewriter.notifyMatchFailure(reduceOp,
"not handled here");
525#undef INT_OR_FLOAT_CASE
528 rewriter.replaceOp(reduceOp,
result);
533class VectorScalarBroadcastPattern final
534 :
public OpConversionPattern<vector::BroadcastOp> {
539 matchAndRewrite(vector::BroadcastOp op, OpAdaptor adaptor,
540 ConversionPatternRewriter &rewriter)
const override {
541 if (isa<VectorType>(op.getSourceType())) {
542 return rewriter.notifyMatchFailure(
543 op,
"only conversion of 'broadcast from scalar' is supported");
545 Type dstType = getTypeConverter()->convertType(op.getType());
548 if (isa<spirv::ScalarType>(dstType)) {
549 rewriter.replaceOp(op, adaptor.getSource());
551 auto dstVecType = cast<VectorType>(dstType);
552 SmallVector<Value, 4> source(dstVecType.getNumElements(),
553 adaptor.getSource());
554 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstType,
561struct VectorShuffleOpConvert final
562 :
public OpConversionPattern<vector::ShuffleOp> {
566 matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
567 ConversionPatternRewriter &rewriter)
const override {
568 VectorType oldResultType = shuffleOp.getResultVectorType();
569 Type newResultType = getTypeConverter()->convertType(oldResultType);
571 return rewriter.notifyMatchFailure(shuffleOp,
572 "unsupported result vector type");
574 auto mask = llvm::to_vector_of<int32_t>(shuffleOp.getMask());
576 VectorType oldV1Type = shuffleOp.getV1VectorType();
577 VectorType oldV2Type = shuffleOp.getV2VectorType();
581 if (oldV1Type.getNumElements() > 1 && oldV2Type.getNumElements() > 1 &&
582 oldResultType.getNumElements() > 1) {
583 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
584 shuffleOp, newResultType, adaptor.getV1(), adaptor.getV2(),
585 rewriter.getI32ArrayAttr(mask));
592 auto getElementAtIdx = [&rewriter, loc = shuffleOp.getLoc()](
593 Value scalarOrVec, int32_t idx) -> Value {
594 if (
auto vecTy = dyn_cast<VectorType>(scalarOrVec.getType()))
595 return spirv::CompositeExtractOp::create(rewriter, loc, scalarOrVec,
598 assert(idx == 0 &&
"Invalid scalar element index");
602 int32_t numV1Elems = oldV1Type.getNumElements();
603 SmallVector<Value> newOperands(mask.size());
604 for (
auto [shuffleIdx, newOperand] : llvm::zip_equal(mask, newOperands)) {
605 Value vec = adaptor.getV1();
606 int32_t elementIdx = shuffleIdx;
607 if (elementIdx >= numV1Elems) {
608 vec = adaptor.getV2();
609 elementIdx -= numV1Elems;
612 newOperand = getElementAtIdx(vec, elementIdx);
616 if (newOperands.size() == 1) {
617 rewriter.replaceOp(shuffleOp, newOperands.front());
621 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
622 shuffleOp, newResultType, newOperands);
627struct VectorInterleaveOpConvert final
628 :
public OpConversionPattern<vector::InterleaveOp> {
632 matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
633 ConversionPatternRewriter &rewriter)
const override {
635 VectorType oldResultType = interleaveOp.getResultVectorType();
636 Type newResultType = getTypeConverter()->convertType(oldResultType);
638 return rewriter.notifyMatchFailure(interleaveOp,
639 "unsupported result vector type");
642 VectorType sourceType = interleaveOp.getSourceVectorType();
643 int n = sourceType.getNumElements();
649 Value newOperands[] = {adaptor.getLhs(), adaptor.getRhs()};
650 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
651 interleaveOp, newResultType, newOperands);
655 auto seq = llvm::seq<int64_t>(2 * n);
656 auto indices = llvm::map_to_vector(
657 seq, [n](
int i) {
return (i % 2 ? n : 0) + i / 2; });
660 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
661 interleaveOp, newResultType, adaptor.getLhs(), adaptor.getRhs(),
662 rewriter.getI32ArrayAttr(
indices));
668struct VectorDeinterleaveOpConvert final
669 :
public OpConversionPattern<vector::DeinterleaveOp> {
673 matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
674 ConversionPatternRewriter &rewriter)
const override {
677 VectorType oldResultType = deinterleaveOp.getResultVectorType();
678 Type newResultType = getTypeConverter()->convertType(oldResultType);
680 return rewriter.notifyMatchFailure(deinterleaveOp,
681 "unsupported result vector type");
683 Location loc = deinterleaveOp->getLoc();
686 Value sourceVector = adaptor.getSource();
687 VectorType sourceType = deinterleaveOp.getSourceVectorType();
688 int n = sourceType.getNumElements();
694 auto elem0 = spirv::CompositeExtractOp::create(
695 rewriter, loc, newResultType, sourceVector,
696 rewriter.getI32ArrayAttr({0}));
698 auto elem1 = spirv::CompositeExtractOp::create(
699 rewriter, loc, newResultType, sourceVector,
700 rewriter.getI32ArrayAttr({1}));
702 rewriter.replaceOp(deinterleaveOp, {elem0, elem1});
707 auto seqEven = llvm::seq<int64_t>(n / 2);
709 llvm::map_to_vector(seqEven, [](
int i) {
return i * 2; });
712 auto seqOdd = llvm::seq<int64_t>(n / 2);
714 llvm::map_to_vector(seqOdd, [](
int i) {
return i * 2 + 1; });
717 auto shuffleEven = spirv::VectorShuffleOp::create(
718 rewriter, loc, newResultType, sourceVector, sourceVector,
719 rewriter.getI32ArrayAttr(indicesEven));
721 auto shuffleOdd = spirv::VectorShuffleOp::create(
722 rewriter, loc, newResultType, sourceVector, sourceVector,
723 rewriter.getI32ArrayAttr(indicesOdd));
725 rewriter.replaceOp(deinterleaveOp, {shuffleEven, shuffleOdd});
730struct VectorLoadOpConverter final
731 :
public OpConversionPattern<vector::LoadOp> {
735 matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
736 ConversionPatternRewriter &rewriter)
const override {
737 auto memrefType = loadOp.getMemRefType();
739 dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
741 return rewriter.notifyMatchFailure(
742 loadOp,
"expected spirv.storage_class memory space");
744 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
745 auto loc = loadOp.getLoc();
748 adaptor.getIndices(), loc, rewriter);
750 return rewriter.notifyMatchFailure(
751 loadOp,
"failed to get memref element pointer");
753 spirv::StorageClass storageClass = attr.getValue();
754 auto vectorType = loadOp.getVectorType();
757 auto spirvVectorType = typeConverter.convertType(vectorType);
758 if (!spirvVectorType)
759 return rewriter.notifyMatchFailure(loadOp,
"unsupported vector type");
763 std::optional<uint64_t> alignment = loadOp.getAlignment();
764 if (alignment > std::numeric_limits<uint32_t>::max()) {
765 return rewriter.notifyMatchFailure(loadOp,
766 "invalid alignment requirement");
769 auto memoryAccess = spirv::MemoryAccess::None;
770 spirv::MemoryAccessAttr memoryAccessAttr;
771 IntegerAttr alignmentAttr;
772 if (alignment.has_value()) {
773 memoryAccess |= spirv::MemoryAccess::Aligned;
775 spirv::MemoryAccessAttr::get(rewriter.getContext(), memoryAccess);
776 alignmentAttr = rewriter.getI32IntegerAttr(alignment.value());
782 Value castedAccessChain =
783 (vectorType.getNumElements() == 1)
785 : spirv::BitcastOp::create(rewriter, loc, vectorPtrType,
788 rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, spirvVectorType,
790 memoryAccessAttr, alignmentAttr);
796struct VectorStoreOpConverter final
797 :
public OpConversionPattern<vector::StoreOp> {
801 matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
802 ConversionPatternRewriter &rewriter)
const override {
803 auto memrefType = storeOp.getMemRefType();
805 dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
807 return rewriter.notifyMatchFailure(
808 storeOp,
"expected spirv.storage_class memory space");
810 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
811 auto loc = storeOp.getLoc();
814 adaptor.getIndices(), loc, rewriter);
816 return rewriter.notifyMatchFailure(
817 storeOp,
"failed to get memref element pointer");
819 std::optional<uint64_t> alignment = storeOp.getAlignment();
820 if (alignment > std::numeric_limits<uint32_t>::max()) {
821 return rewriter.notifyMatchFailure(storeOp,
822 "invalid alignment requirement");
825 spirv::StorageClass storageClass = attr.getValue();
826 auto vectorType = storeOp.getVectorType();
829 auto spirvVectorType = typeConverter.convertType(vectorType);
830 if (!spirvVectorType)
831 return rewriter.notifyMatchFailure(storeOp,
"unsupported vector type");
838 Value castedAccessChain =
839 (vectorType.getNumElements() == 1)
841 : spirv::BitcastOp::create(rewriter, loc, vectorPtrType,
844 auto memoryAccess = spirv::MemoryAccess::None;
845 spirv::MemoryAccessAttr memoryAccessAttr;
846 IntegerAttr alignmentAttr;
847 if (alignment.has_value()) {
848 memoryAccess |= spirv::MemoryAccess::Aligned;
850 spirv::MemoryAccessAttr::get(rewriter.getContext(), memoryAccess);
851 alignmentAttr = rewriter.getI32IntegerAttr(alignment.value());
854 rewriter.replaceOpWithNewOp<spirv::StoreOp>(
855 storeOp, castedAccessChain, adaptor.getValueToStore(), memoryAccessAttr,
862struct VectorReductionToIntDotProd final
866 LogicalResult matchAndRewrite(vector::ReductionOp op,
867 PatternRewriter &rewriter)
const override {
868 if (op.getKind() != vector::CombiningKind::ADD)
871 auto resultType = dyn_cast<IntegerType>(op.getType());
876 if (!llvm::is_contained({32, 64}, resultBitwidth))
879 VectorType inVecTy = op.getSourceVectorType();
880 if (!llvm::is_contained({4, 3}, inVecTy.getNumElements()) ||
881 inVecTy.getShape().size() != 1 || inVecTy.isScalable())
884 auto mul = op.getVector().getDefiningOp<arith::MulIOp>();
887 op,
"reduction operand is not 'arith.muli'");
889 if (succeeded(handleCase<arith::ExtSIOp, arith::ExtSIOp, spirv::SDotOp,
890 spirv::SDotAccSatOp,
false>(op,
mul, rewriter)))
893 if (succeeded(handleCase<arith::ExtUIOp, arith::ExtUIOp, spirv::UDotOp,
894 spirv::UDotAccSatOp,
false>(op,
mul, rewriter)))
897 if (succeeded(handleCase<arith::ExtSIOp, arith::ExtUIOp, spirv::SUDotOp,
898 spirv::SUDotAccSatOp,
false>(op,
mul, rewriter)))
901 if (succeeded(handleCase<arith::ExtUIOp, arith::ExtSIOp, spirv::SUDotOp,
902 spirv::SUDotAccSatOp,
true>(op,
mul, rewriter)))
909 template <
typename LhsExtensionOp,
typename RhsExtensionOp,
typename DotOp,
910 typename DotAccOp,
bool SwapOperands>
911 static LogicalResult handleCase(vector::ReductionOp op, arith::MulIOp
mul,
912 PatternRewriter &rewriter) {
913 auto lhs =
mul.getLhs().getDefiningOp<LhsExtensionOp>();
916 Value lhsIn =
lhs.getIn();
917 auto lhsInType = cast<VectorType>(lhsIn.
getType());
918 if (!lhsInType.getElementType().isInteger(8))
921 auto rhs =
mul.getRhs().getDefiningOp<RhsExtensionOp>();
924 Value rhsIn =
rhs.getIn();
925 auto rhsInType = cast<VectorType>(rhsIn.
getType());
926 if (!rhsInType.getElementType().isInteger(8))
929 if (op.getSourceVectorType().getNumElements() == 3) {
930 IntegerType i8Type = rewriter.
getI8Type();
931 auto v4i8Type = VectorType::get({4}, i8Type);
932 Location loc = op.getLoc();
933 Value zero = spirv::ConstantOp::getZero(i8Type, loc, rewriter);
934 lhsIn = spirv::CompositeConstructOp::create(rewriter, loc, v4i8Type,
936 rhsIn = spirv::CompositeConstructOp::create(rewriter, loc, v4i8Type,
943 std::swap(lhsIn, rhsIn);
945 if (Value acc = op.getAcc()) {
957struct VectorReductionToFPDotProd final
958 : OpConversionPattern<vector::ReductionOp> {
962 matchAndRewrite(vector::ReductionOp op, OpAdaptor adaptor,
963 ConversionPatternRewriter &rewriter)
const override {
964 if (op.getKind() != vector::CombiningKind::ADD)
965 return rewriter.notifyMatchFailure(op,
"combining kind is not 'add'");
967 auto resultType = getTypeConverter()->convertType<FloatType>(op.getType());
969 return rewriter.notifyMatchFailure(op,
"result is not a float");
971 Value vec = adaptor.getVector();
972 Value acc = adaptor.getAcc();
974 auto vectorType = dyn_cast<VectorType>(vec.
getType());
976 assert(isa<FloatType>(vec.
getType()) &&
977 "Expected the vector to be scalarized");
979 rewriter.replaceOpWithNewOp<spirv::FAddOp>(op, acc, vec);
983 rewriter.replaceOp(op, vec);
987 Location loc = op.getLoc();
998 rewriter.getFloatAttr(vectorType.getElementType(), 1.0);
999 oneAttr = SplatElementsAttr::get(vectorType, oneAttr);
1000 rhs = spirv::ConstantOp::create(rewriter, loc, vectorType, oneAttr);
1005 Value res = spirv::DotOp::create(rewriter, loc, resultType,
lhs,
rhs);
1007 res = spirv::FAddOp::create(rewriter, loc, acc, res);
1009 rewriter.replaceOp(op, res);
1014struct VectorStepOpConvert final : OpConversionPattern<vector::StepOp> {
1018 matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
1019 ConversionPatternRewriter &rewriter)
const override {
1020 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
1021 Type dstType = typeConverter.convertType(stepOp.getType());
1025 Location loc = stepOp.getLoc();
1026 int64_t numElements = stepOp.getType().getNumElements();
1028 rewriter.getIntegerType(typeConverter.getIndexTypeBitwidth());
1032 if (numElements == 1) {
1033 Value zero = spirv::ConstantOp::getZero(intType, loc, rewriter);
1034 rewriter.replaceOp(stepOp, zero);
1038 SmallVector<Value> source;
1039 source.reserve(numElements);
1040 for (int64_t i = 0; i < numElements; ++i) {
1041 Attribute intAttr = rewriter.getIntegerAttr(intType, i);
1043 spirv::ConstantOp::create(rewriter, loc, intType, intAttr);
1044 source.push_back(constOp);
1046 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(stepOp, dstType,
1052struct VectorToElementOpConvert final
1053 : OpConversionPattern<vector::ToElementsOp> {
1057 matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
1058 ConversionPatternRewriter &rewriter)
const override {
1060 SmallVector<Value> results(toElementsOp->getNumResults());
1061 Location loc = toElementsOp.getLoc();
1066 if (isa<spirv::ScalarType>(adaptor.getSource().getType())) {
1067 results[0] = adaptor.getSource();
1068 rewriter.replaceOp(toElementsOp, results);
1072 Type srcElementType = toElementsOp.getElements().getType().front();
1073 Type elementType = getTypeConverter()->convertType(srcElementType);
1075 return rewriter.notifyMatchFailure(
1077 llvm::formatv(
"failed to convert element type '{0}' to SPIR-V",
1080 for (
auto [idx, element] : llvm::enumerate(toElementsOp.getElements())) {
1083 if (element.use_empty())
1086 Value
result = spirv::CompositeExtractOp::create(
1087 rewriter, loc, elementType, adaptor.getSource(),
1088 rewriter.getI32ArrayAttr({static_cast<int32_t>(idx)}));
1092 rewriter.replaceOp(toElementsOp, results);
1098#define CL_INT_MAX_MIN_OPS \
1099 spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
1101#define GL_INT_MAX_MIN_OPS \
1102 spirv::GLUMaxOp, spirv::GLUMinOp, spirv::GLSMaxOp, spirv::GLSMinOp
1104#define CL_FLOAT_MAX_MIN_OPS spirv::CLFMaxOp, spirv::CLFMinOp
1105#define GL_FLOAT_MAX_MIN_OPS spirv::GLFMaxOp, spirv::GLFMinOp
1110 VectorBitcastConvert, VectorBroadcastConvert, VectorExtractOpConvert,
1111 VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
1112 VectorFmaOpConvert<spirv::CLFmaOp>, VectorFromElementsOpConvert,
1113 VectorToElementOpConvert, VectorInsertOpConvert,
1114 VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
1115 VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
1116 VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
1117 VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
1118 VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
1119 VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
1120 VectorScalarBroadcastPattern, VectorLoadOpConverter,
1121 VectorStoreOpConverter, VectorStepOpConvert>(
1126 patterns.
add<VectorReductionToFPDotProd>(typeConverter, patterns.
getContext(),
1132 patterns.
add<VectorReductionToIntDotProd>(patterns.
getContext());
static constexpr unsigned getNumBits()
#define INT_CASE(kind, iop)
static uint64_t getFirstIntValue(ArrayAttr attr)
Returns the integer value from the first valid input element, assuming Value inputs are defined by a ...
#define INT_AND_FLOAT_CASE(kind, iop, fop)
#define INT_OR_FLOAT_CASE(kind, fop)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static PointerType get(Type pointeeType, StorageClass storageClass)
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void populateVectorReductionToSPIRVDotProductPatterns(RewritePatternSet &patterns)
Appends patterns to convert vector reduction of the form:
void populateVectorToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating Vector Ops to SPIR-V ops.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...