28#define DEBUG_TYPE "memref-to-spirv-pattern"
49 assert(targetBits % sourceBits == 0);
51 IntegerAttr idxAttr = builder.
getIntegerAttr(type, targetBits / sourceBits);
52 auto idx = builder.
createOrFold<spirv::ConstantOp>(loc, type, idxAttr);
53 IntegerAttr srcBitsAttr = builder.
getIntegerAttr(type, sourceBits);
55 builder.
createOrFold<spirv::ConstantOp>(loc, type, srcBitsAttr);
56 auto m = builder.
createOrFold<spirv::UModOp>(loc, srcIdx, idx);
57 return builder.
createOrFold<spirv::IMulOp>(loc, type, m, srcBitsValue);
70 spirv::AccessChainOp op,
int sourceBits,
72 assert(targetBits % sourceBits == 0);
73 const auto loc = op.getLoc();
74 Value lastDim = op->getOperand(op.getNumOperands() - 1);
76 IntegerAttr attr = builder.
getIntegerAttr(type, targetBits / sourceBits);
77 auto idx = builder.
createOrFold<spirv::ConstantOp>(loc, type, attr);
78 auto indices = llvm::to_vector<4>(op.getIndices());
82 Type t = typeConverter.convertType(op.getComponentPtr().getType());
83 return spirv::AccessChainOp::create(builder, loc, t, op.getBasePtr(),
93 Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
94 Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
95 return builder.
createOrFold<spirv::SelectOp>(loc, dstType, srcBool, one,
103 IntegerType dstType = cast<IntegerType>(mask.
getType());
104 int targetBits =
static_cast<int>(dstType.getWidth());
106 assert(valueBits <= targetBits);
108 if (valueBits == 1) {
111 if (valueBits < targetBits) {
112 value = spirv::UConvertOp::create(
116 value = builder.
createOrFold<spirv::BitwiseAndOp>(loc, value, mask);
125 if (isa<memref::AllocOp, memref::DeallocOp>(allocOp)) {
126 auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
127 if (!sc || sc.getValue() != spirv::StorageClass::Workgroup)
129 }
else if (isa<memref::AllocaOp>(allocOp)) {
130 auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
131 if (!sc || sc.getValue() != spirv::StorageClass::Function)
139 if (!type.hasStaticShape())
142 Type elementType = type.getElementType();
143 if (
auto vecType = dyn_cast<VectorType>(elementType))
144 elementType = vecType.getElementType();
145 if (
auto compType = dyn_cast<ComplexType>(elementType))
146 elementType = compType.getElementType();
154 auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
155 switch (sc.getValue()) {
156 case spirv::StorageClass::StorageBuffer:
157 return spirv::Scope::Device;
158 case spirv::StorageClass::Workgroup:
159 return spirv::Scope::Workgroup;
175 if (typeConverter.
allows(spirv::Capability::Kernel)) {
176 if (
auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
177 return arrayType.getElementType();
181 Type structElemType = cast<spirv::StructType>(pointeeType).getElementType(0);
182 if (
auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
183 return arrayType.getElementType();
184 return cast<spirv::RuntimeArrayType>(structElemType).getElementType();
192 auto one = spirv::ConstantOp::getZero(srcInt.
getType(), loc, builder);
193 return builder.
createOrFold<spirv::INotEqualOp>(loc, srcInt, one);
207class AllocaOpPattern final :
public OpConversionPattern<memref::AllocaOp> {
212 matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
213 ConversionPatternRewriter &rewriter)
const override;
220class AllocOpPattern final :
public OpConversionPattern<memref::AllocOp> {
225 matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
226 ConversionPatternRewriter &rewriter)
const override;
230class AtomicRMWOpPattern final
231 :
public OpConversionPattern<memref::AtomicRMWOp> {
236 matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
237 ConversionPatternRewriter &rewriter)
const override;
242class DeallocOpPattern final :
public OpConversionPattern<memref::DeallocOp> {
247 matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor,
248 ConversionPatternRewriter &rewriter)
const override;
252class IntLoadOpPattern final :
public OpConversionPattern<memref::LoadOp> {
257 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
258 ConversionPatternRewriter &rewriter)
const override;
262class LoadOpPattern final :
public OpConversionPattern<memref::LoadOp> {
267 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
268 ConversionPatternRewriter &rewriter)
const override;
272class ImageLoadOpPattern final :
public OpConversionPattern<memref::LoadOp> {
277 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
278 ConversionPatternRewriter &rewriter)
const override;
282class IntStoreOpPattern final :
public OpConversionPattern<memref::StoreOp> {
287 matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
288 ConversionPatternRewriter &rewriter)
const override;
292class MemorySpaceCastOpPattern final
293 :
public OpConversionPattern<memref::MemorySpaceCastOp> {
298 matchAndRewrite(memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
299 ConversionPatternRewriter &rewriter)
const override;
303class StoreOpPattern final :
public OpConversionPattern<memref::StoreOp> {
308 matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
309 ConversionPatternRewriter &rewriter)
const override;
312class ReinterpretCastPattern final
313 :
public OpConversionPattern<memref::ReinterpretCastOp> {
318 matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
319 ConversionPatternRewriter &rewriter)
const override;
322class CastPattern final :
public OpConversionPattern<memref::CastOp> {
327 matchAndRewrite(memref::CastOp op, OpAdaptor adaptor,
328 ConversionPatternRewriter &rewriter)
const override {
329 Value src = adaptor.getSource();
332 const TypeConverter *converter = getTypeConverter();
333 Type dstType = converter->convertType(op.getType());
334 if (srcType != dstType)
335 return rewriter.notifyMatchFailure(op, [&](Diagnostic &
diag) {
336 diag <<
"types doesn't match: " << srcType <<
" and " << dstType;
339 rewriter.replaceOp(op, src);
345class ExtractAlignedPointerAsIndexOpPattern final
346 :
public OpConversionPattern<memref::ExtractAlignedPointerAsIndexOp> {
351 matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
353 ConversionPatternRewriter &rewriter)
const override;
362AllocaOpPattern::matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
363 ConversionPatternRewriter &rewriter)
const {
364 MemRefType allocType = allocaOp.getType();
366 return rewriter.notifyMatchFailure(allocaOp,
"unhandled allocation type");
369 Type spirvType = getTypeConverter()->convertType(allocType);
371 return rewriter.notifyMatchFailure(allocaOp,
"type conversion failed");
373 rewriter.replaceOpWithNewOp<spirv::VariableOp>(allocaOp, spirvType,
374 spirv::StorageClass::Function,
384AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
385 ConversionPatternRewriter &rewriter)
const {
386 MemRefType allocType = operation.getType();
388 return rewriter.notifyMatchFailure(operation,
"unhandled allocation type");
391 Type spirvType = getTypeConverter()->convertType(allocType);
393 return rewriter.notifyMatchFailure(operation,
"type conversion failed");
400 Location loc = operation.getLoc();
401 spirv::GlobalVariableOp varOp;
403 OpBuilder::InsertionGuard guard(rewriter);
405 rewriter.setInsertionPointToStart(&entryBlock);
406 auto varOps = entryBlock.
getOps<spirv::GlobalVariableOp>();
407 std::string varName =
408 std::string(
"__workgroup_mem__") +
409 std::to_string(std::distance(varOps.begin(), varOps.end()));
410 varOp = spirv::GlobalVariableOp::create(rewriter, loc, spirvType, varName,
415 rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(operation, varOp);
424AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
426 ConversionPatternRewriter &rewriter)
const {
427 if (isa<FloatType>(atomicOp.getType()))
428 return rewriter.notifyMatchFailure(atomicOp,
429 "unimplemented floating-point case");
431 auto memrefType = cast<MemRefType>(atomicOp.getMemref().getType());
434 return rewriter.notifyMatchFailure(atomicOp,
435 "unsupported memref memory space");
437 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
438 Type resultType = typeConverter.convertType(atomicOp.getType());
440 return rewriter.notifyMatchFailure(atomicOp,
441 "failed to convert result type");
443 auto loc = atomicOp.getLoc();
446 adaptor.getIndices(), loc, rewriter);
454 int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
455 auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
457 return rewriter.notifyMatchFailure(atomicOp,
458 "failed to convert memref type");
460 Type pointeeType = pointerType.getPointeeType();
461 auto dstType = dyn_cast<IntegerType>(
464 return rewriter.notifyMatchFailure(
465 atomicOp,
"failed to determine destination element type");
467 int dstBits =
static_cast<int>(dstType.getWidth());
468 assert(dstBits % srcBits == 0);
472 if (srcBits == dstBits) {
473#define ATOMIC_CASE(kind, spirvOp) \
474 case arith::AtomicRMWKind::kind: \
475 rewriter.replaceOpWithNewOp<spirv::spirvOp>( \
476 atomicOp, resultType, ptr, *scope, \
477 spirv::MemorySemantics::AcquireRelease, adaptor.getValue()); \
480 switch (atomicOp.getKind()) {
489 return rewriter.notifyMatchFailure(atomicOp,
"unimplemented atomic kind");
504 if (atomicOp.getKind() != arith::AtomicRMWKind::ori &&
505 atomicOp.getKind() != arith::AtomicRMWKind::andi) {
506 return rewriter.notifyMatchFailure(
508 "atomic op on sub-element-width types is only supported for ori/andi");
513 if (typeConverter.allows(spirv::Capability::Kernel))
514 return rewriter.notifyMatchFailure(
516 "sub-element-width atomic ops unsupported with Kernel capability");
518 auto accessChainOp = ptr.
getDefiningOp<spirv::AccessChainOp>();
524 assert(accessChainOp.getIndices().size() == 2);
525 Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
528 srcBits, dstBits, rewriter);
530 switch (atomicOp.getKind()) {
531 case arith::AtomicRMWKind::ori: {
534 Value elemMask = rewriter.createOrFold<spirv::ConstantOp>(
535 loc, dstType, rewriter.getIntegerAttr(dstType, (1uLL << srcBits) - 1));
537 shiftValue(loc, adaptor.getValue(), offset, elemMask, rewriter);
538 result = spirv::AtomicOrOp::create(
539 rewriter, loc, dstType, adjustedPtr, *scope,
540 spirv::MemorySemantics::AcquireRelease, storeVal);
543 case arith::AtomicRMWKind::andi: {
547 Value elemMask = rewriter.createOrFold<spirv::ConstantOp>(
548 loc, dstType, rewriter.getIntegerAttr(dstType, (1uLL << srcBits) - 1));
550 shiftValue(loc, adaptor.getValue(), offset, elemMask, rewriter);
551 Value shiftedElemMask = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(
552 loc, dstType, elemMask, offset);
553 Value invertedElemMask =
554 rewriter.createOrFold<spirv::NotOp>(loc, dstType, shiftedElemMask);
555 Value mask = rewriter.createOrFold<spirv::BitwiseOrOp>(loc, storeVal,
557 result = spirv::AtomicAndOp::create(
558 rewriter, loc, dstType, adjustedPtr, *scope,
559 spirv::MemorySemantics::AcquireRelease, mask);
563 return rewriter.notifyMatchFailure(atomicOp,
"unimplemented atomic kind");
568 result = rewriter.createOrFold<spirv::ShiftRightLogicalOp>(loc, dstType,
570 Value mask = rewriter.createOrFold<spirv::ConstantOp>(
571 loc, dstType, rewriter.getIntegerAttr(dstType, (1uLL << srcBits) - 1));
573 rewriter.createOrFold<spirv::BitwiseAndOp>(loc, dstType,
result, mask);
574 rewriter.replaceOp(atomicOp,
result);
584DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
586 ConversionPatternRewriter &rewriter)
const {
587 MemRefType deallocType = cast<MemRefType>(operation.getMemref().getType());
589 return rewriter.notifyMatchFailure(operation,
"unhandled allocation type");
590 rewriter.eraseOp(operation);
605static FailureOr<MemoryRequirements>
607 uint64_t preferredAlignment) {
608 if (preferredAlignment >= std::numeric_limits<uint32_t>::max()) {
614 auto memoryAccess = spirv::MemoryAccess::None;
616 memoryAccess = spirv::MemoryAccess::Nontemporal;
619 auto ptrType = cast<spirv::PointerType>(accessedPtr.
getType());
620 bool mayOmitAlignment =
621 !preferredAlignment &&
622 ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer;
623 if (mayOmitAlignment) {
624 if (memoryAccess == spirv::MemoryAccess::None) {
633 std::optional<int64_t> sizeInBytes;
634 Type rawPointeeType = ptrType.getPointeeType();
635 if (
auto scalarType = dyn_cast<spirv::ScalarType>(rawPointeeType)) {
637 sizeInBytes = scalarType.getSizeInBytes();
638 }
else if (
auto vecType = dyn_cast<VectorType>(rawPointeeType)) {
641 if (
auto scalarElem =
642 dyn_cast<spirv::ScalarType>(vecType.getElementType())) {
643 if (
auto elemSize = scalarElem.getSizeInBytes())
644 sizeInBytes = *elemSize * vecType.getNumElements();
648 if (!sizeInBytes.has_value())
651 memoryAccess |= spirv::MemoryAccess::Aligned;
652 auto memAccessAttr = spirv::MemoryAccessAttr::get(ctx, memoryAccess);
653 auto alignmentValue = preferredAlignment ? preferredAlignment : *sizeInBytes;
654 auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), alignmentValue);
661template <
class LoadOrStoreOp>
662static FailureOr<MemoryRequirements>
665 llvm::is_one_of<LoadOrStoreOp, memref::LoadOp, memref::StoreOp>::value,
666 "Must be called on either memref::LoadOp or memref::StoreOp");
669 loadOrStoreOp.getNontemporal(),
670 loadOrStoreOp.getAlignment().value_or(0));
674IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
675 ConversionPatternRewriter &rewriter)
const {
676 auto loc = loadOp.getLoc();
677 auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
678 if (!memrefType.getElementType().isSignlessInteger())
681 auto memorySpaceAttr =
682 dyn_cast_if_present<spirv::StorageClassAttr>(memrefType.getMemorySpace());
683 if (!memorySpaceAttr)
684 return rewriter.notifyMatchFailure(
685 loadOp,
"missing memory space SPIR-V storage class attribute");
687 if (memorySpaceAttr.getValue() == spirv::StorageClass::Image)
688 return rewriter.notifyMatchFailure(
690 "failed to lower memref in image storage class to storage buffer");
692 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
695 adaptor.getIndices(), loc, rewriter);
700 int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
701 bool isBool = srcBits == 1;
703 srcBits = typeConverter.getOptions().boolNumBits;
705 auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
707 return rewriter.notifyMatchFailure(loadOp,
"failed to convert memref type");
709 Type pointeeType = pointerType.getPointeeType();
712 assert(dstBits % srcBits == 0);
716 if (srcBits == dstBits) {
718 if (
failed(memoryRequirements))
719 return rewriter.notifyMatchFailure(
720 loadOp,
"failed to determine memory requirements");
722 auto [memoryAccess, alignment] = *memoryRequirements;
723 Value loadVal = spirv::LoadOp::create(rewriter, loc, accessChain,
724 memoryAccess, alignment);
727 rewriter.replaceOp(loadOp, loadVal);
733 if (typeConverter.allows(spirv::Capability::Kernel))
736 auto accessChainOp = accessChain.
getDefiningOp<spirv::AccessChainOp>();
743 assert(accessChainOp.getIndices().size() == 2);
745 srcBits, dstBits, rewriter);
747 if (
failed(memoryRequirements))
748 return rewriter.notifyMatchFailure(
749 loadOp,
"failed to determine memory requirements");
751 auto [memoryAccess, alignment] = *memoryRequirements;
752 Value spvLoadOp = spirv::LoadOp::create(rewriter, loc, dstType, adjustedPtr,
753 memoryAccess, alignment);
757 Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
759 Value
result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
760 loc, spvLoadOp.
getType(), spvLoadOp, offset);
763 Value mask = rewriter.createOrFold<spirv::ConstantOp>(
764 loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
766 rewriter.createOrFold<spirv::BitwiseAndOp>(loc, dstType,
result, mask);
771 IntegerAttr shiftValueAttr =
772 rewriter.getIntegerAttr(dstType, dstBits - srcBits);
774 rewriter.createOrFold<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
775 result = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(loc, dstType,
777 result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
780 rewriter.replaceOp(loadOp,
result);
782 assert(accessChainOp.use_empty());
783 rewriter.eraseOp(accessChainOp);
789LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
790 ConversionPatternRewriter &rewriter)
const {
791 auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
792 if (memrefType.getElementType().isSignlessInteger())
795 auto memorySpaceAttr =
796 dyn_cast_if_present<spirv::StorageClassAttr>(memrefType.getMemorySpace());
797 if (!memorySpaceAttr)
798 return rewriter.notifyMatchFailure(
799 loadOp,
"missing memory space SPIR-V storage class attribute");
801 if (memorySpaceAttr.getValue() == spirv::StorageClass::Image)
802 return rewriter.notifyMatchFailure(
804 "failed to lower memref in image storage class to storage buffer");
807 *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
808 adaptor.getIndices(), loadOp.getLoc(), rewriter);
814 if (
failed(memoryRequirements))
815 return rewriter.notifyMatchFailure(
816 loadOp,
"failed to determine memory requirements");
818 auto [memoryAccess, alignment] = *memoryRequirements;
819 rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr, memoryAccess,
824template <
typename OpAdaptor>
825static FailureOr<SmallVector<Value>>
827 ConversionPatternRewriter &rewriter) {
834 AffineMap map = loadOp.getMemRefType().getLayout().getAffineMap();
836 return rewriter.notifyMatchFailure(
838 "Cannot lower memrefs with memory layout which is not a permutation");
844 for (
unsigned dim = 0; dim < dimCount; ++dim)
850 return llvm::to_vector(llvm::reverse(coords));
854ImageLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
855 ConversionPatternRewriter &rewriter)
const {
856 auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
858 auto memorySpaceAttr =
859 dyn_cast_if_present<spirv::StorageClassAttr>(memrefType.getMemorySpace());
860 if (!memorySpaceAttr)
861 return rewriter.notifyMatchFailure(
862 loadOp,
"missing memory space SPIR-V storage class attribute");
864 if (memorySpaceAttr.getValue() != spirv::StorageClass::Image)
865 return rewriter.notifyMatchFailure(
866 loadOp,
"failed to lower memref in non-image storage class to image");
868 Value loadPtr = adaptor.getMemref();
870 if (
failed(memoryRequirements))
871 return rewriter.notifyMatchFailure(
872 loadOp,
"failed to determine memory requirements");
874 const auto [memoryAccess, alignment] = *memoryRequirements;
876 if (!loadOp.getMemRefType().hasRank())
877 return rewriter.notifyMatchFailure(
878 loadOp,
"cannot lower unranked memrefs to SPIR-V images");
883 if (!isa<spirv::ScalarType>(loadOp.getMemRefType().getElementType()))
884 return rewriter.notifyMatchFailure(
886 "cannot lower memrefs who's element type is not a SPIR-V scalar type"
893 auto convertedPointeeType = cast<spirv::PointerType>(
894 getTypeConverter()->convertType(loadOp.getMemRefType()));
895 if (!isa<spirv::SampledImageType>(convertedPointeeType.getPointeeType()))
896 return rewriter.notifyMatchFailure(loadOp,
897 "cannot lower memrefs which do not "
898 "convert to SPIR-V sampled images");
901 Location loc = loadOp->getLoc();
903 spirv::LoadOp::create(rewriter, loc, loadPtr, memoryAccess, alignment);
905 auto imageOp = spirv::ImageOp::create(rewriter, loc, imageLoadOp);
909 if (memrefType.getRank() == 1) {
910 coords = adaptor.getIndices()[0];
912 FailureOr<SmallVector<Value>> maybeCoords =
916 auto coordVectorType = VectorType::get({loadOp.getMemRefType().getRank()},
917 adaptor.getIndices().
getType()[0]);
918 coords = spirv::CompositeConstructOp::create(rewriter, loc, coordVectorType,
919 maybeCoords.value());
923 auto resultVectorType = VectorType::get({4}, loadOp.getType());
924 auto fetchOp = spirv::ImageFetchOp::create(
925 rewriter, loc, resultVectorType, imageOp, coords,
926 mlir::spirv::ImageOperandsAttr{},
ValueRange{});
931 auto compositeExtractOp =
932 spirv::CompositeExtractOp::create(rewriter, loc, fetchOp, 0);
934 rewriter.replaceOp(loadOp, compositeExtractOp);
939IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
940 ConversionPatternRewriter &rewriter)
const {
941 auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
942 if (!memrefType.getElementType().isSignlessInteger())
943 return rewriter.notifyMatchFailure(storeOp,
944 "element type is not a signless int");
946 auto loc = storeOp.getLoc();
947 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
950 adaptor.getIndices(), loc, rewriter);
953 return rewriter.notifyMatchFailure(
954 storeOp,
"failed to convert element pointer type");
956 int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
958 bool isBool = srcBits == 1;
960 srcBits = typeConverter.getOptions().boolNumBits;
962 auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
964 return rewriter.notifyMatchFailure(storeOp,
965 "failed to convert memref type");
967 Type pointeeType = pointerType.getPointeeType();
968 auto dstType = dyn_cast<IntegerType>(
971 return rewriter.notifyMatchFailure(
972 storeOp,
"failed to determine destination element type");
974 int dstBits =
static_cast<int>(dstType.getWidth());
975 assert(dstBits % srcBits == 0);
977 if (srcBits == dstBits) {
979 if (
failed(memoryRequirements))
980 return rewriter.notifyMatchFailure(
981 storeOp,
"failed to determine memory requirements");
983 auto [memoryAccess, alignment] = *memoryRequirements;
984 Value storeVal = adaptor.getValue();
987 rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal,
988 memoryAccess, alignment);
994 if (typeConverter.allows(spirv::Capability::Kernel))
997 auto accessChainOp = accessChain.
getDefiningOp<spirv::AccessChainOp>();
1012 assert(accessChainOp.getIndices().size() == 2);
1013 Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
1018 Value mask = rewriter.createOrFold<spirv::ConstantOp>(
1019 loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
1020 Value clearBitsMask = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(
1021 loc, dstType, mask, offset);
1023 rewriter.createOrFold<spirv::NotOp>(loc, dstType, clearBitsMask);
1025 Value storeVal =
shiftValue(loc, adaptor.getValue(), offset, mask, rewriter);
1027 srcBits, dstBits, rewriter);
1030 return rewriter.notifyMatchFailure(storeOp,
"atomic scope not available");
1032 Value
result = spirv::AtomicAndOp::create(
1033 rewriter, loc, dstType, adjustedPtr, *scope,
1034 spirv::MemorySemantics::AcquireRelease, clearBitsMask);
1035 result = spirv::AtomicOrOp::create(
1036 rewriter, loc, dstType, adjustedPtr, *scope,
1037 spirv::MemorySemantics::AcquireRelease, storeVal);
1043 rewriter.eraseOp(storeOp);
1045 assert(accessChainOp.use_empty());
1046 rewriter.eraseOp(accessChainOp);
1055LogicalResult MemorySpaceCastOpPattern::matchAndRewrite(
1056 memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
1057 ConversionPatternRewriter &rewriter)
const {
1058 Location loc = addrCastOp.getLoc();
1059 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
1060 if (!typeConverter.allows(spirv::Capability::Kernel))
1061 return rewriter.notifyMatchFailure(
1062 loc,
"address space casts require kernel capability");
1064 auto sourceType = dyn_cast<MemRefType>(addrCastOp.getSource().getType());
1066 return rewriter.notifyMatchFailure(
1067 loc,
"SPIR-V lowering requires ranked memref types");
1068 auto resultType = cast<MemRefType>(addrCastOp.getResult().getType());
1070 auto sourceStorageClassAttr =
1071 dyn_cast_or_null<spirv::StorageClassAttr>(sourceType.getMemorySpace());
1072 if (!sourceStorageClassAttr)
1073 return rewriter.notifyMatchFailure(loc, [sourceType](Diagnostic &
diag) {
1074 diag <<
"source address space " << sourceType.getMemorySpace()
1075 <<
" must be a SPIR-V storage class";
1077 auto resultStorageClassAttr =
1078 dyn_cast_or_null<spirv::StorageClassAttr>(resultType.getMemorySpace());
1079 if (!resultStorageClassAttr)
1080 return rewriter.notifyMatchFailure(loc, [resultType](Diagnostic &
diag) {
1081 diag <<
"result address space " << resultType.getMemorySpace()
1082 <<
" must be a SPIR-V storage class";
1085 spirv::StorageClass sourceSc = sourceStorageClassAttr.getValue();
1086 spirv::StorageClass resultSc = resultStorageClassAttr.getValue();
1088 Value
result = adaptor.getSource();
1089 Type resultPtrType = typeConverter.convertType(resultType);
1091 return rewriter.notifyMatchFailure(addrCastOp,
1092 "failed to convert memref type");
1094 Type genericPtrType = resultPtrType;
1102 if (sourceSc != spirv::StorageClass::Generic &&
1103 resultSc != spirv::StorageClass::Generic) {
1104 Type intermediateType =
1105 MemRefType::get(sourceType.getShape(), sourceType.getElementType(),
1106 sourceType.getLayout(),
1107 rewriter.getAttr<spirv::StorageClassAttr>(
1108 spirv::StorageClass::Generic));
1109 genericPtrType = typeConverter.convertType(intermediateType);
1111 if (sourceSc != spirv::StorageClass::Generic) {
1112 result = spirv::PtrCastToGenericOp::create(rewriter, loc, genericPtrType,
1115 if (resultSc != spirv::StorageClass::Generic) {
1117 spirv::GenericCastToPtrOp::create(rewriter, loc, resultPtrType,
result);
1119 rewriter.replaceOp(addrCastOp,
result);
1124StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
1125 ConversionPatternRewriter &rewriter)
const {
1126 auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
1127 if (memrefType.getElementType().isSignlessInteger())
1128 return rewriter.notifyMatchFailure(storeOp,
"signless int");
1130 *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
1131 adaptor.getIndices(), storeOp.getLoc(), rewriter);
1134 return rewriter.notifyMatchFailure(storeOp,
"type conversion failed");
1137 if (
failed(memoryRequirements))
1138 return rewriter.notifyMatchFailure(
1139 storeOp,
"failed to determine memory requirements");
1141 auto [memoryAccess, alignment] = *memoryRequirements;
1142 rewriter.replaceOpWithNewOp<spirv::StoreOp>(
1143 storeOp, storePtr, adaptor.getValue(), memoryAccess, alignment);
1147LogicalResult ReinterpretCastPattern::matchAndRewrite(
1148 memref::ReinterpretCastOp op, OpAdaptor adaptor,
1149 ConversionPatternRewriter &rewriter)
const {
1150 Value src = adaptor.getSource();
1151 auto srcType = dyn_cast<spirv::PointerType>(src.
getType());
1154 return rewriter.notifyMatchFailure(op, [&](Diagnostic &
diag) {
1158 const TypeConverter *converter = getTypeConverter();
1160 auto dstType = converter->convertType<spirv::PointerType>(op.getType());
1161 if (dstType != srcType)
1162 return rewriter.notifyMatchFailure(op, [&](Diagnostic &
diag) {
1163 diag <<
"invalid dst type " << op.getType();
1166 OpFoldResult offset =
1167 getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter)
1170 rewriter.replaceOp(op, src);
1174 Type intType = converter->convertType(rewriter.getIndexType());
1176 return rewriter.notifyMatchFailure(op,
"failed to convert index type");
1178 Location loc = op.getLoc();
1179 auto offsetValue = [&]() -> Value {
1180 if (
auto val = dyn_cast<Value>(offset))
1183 int64_t attrVal = cast<IntegerAttr>(cast<Attribute>(offset)).getInt();
1184 Attribute attr = rewriter.getIntegerAttr(intType, attrVal);
1185 return rewriter.createOrFold<spirv::ConstantOp>(loc, intType, attr);
1188 rewriter.replaceOpWithNewOp<spirv::InBoundsPtrAccessChainOp>(
1197LogicalResult ExtractAlignedPointerAsIndexOpPattern::matchAndRewrite(
1198 memref::ExtractAlignedPointerAsIndexOp extractOp, OpAdaptor adaptor,
1199 ConversionPatternRewriter &rewriter)
const {
1200 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
1201 Type indexType = typeConverter.getIndexType();
1202 rewriter.replaceOpWithNewOp<spirv::ConvertPtrToUOp>(extractOp, indexType,
1203 adaptor.getSource());
1214 patterns.
add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
1215 DeallocOpPattern, IntLoadOpPattern, ImageLoadOpPattern,
1216 IntStoreOpPattern, LoadOpPattern, MemorySpaceCastOpPattern,
1217 StoreOpPattern, ReinterpretCastPattern, CastPattern,
1218 ExtractAlignedPointerAsIndexOpPattern>(typeConverter,
static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder)
Casts the given srcInt into a boolean value.
static Type getElementTypeForStoragePointer(Type pointeeType, const SPIRVTypeConverter &typeConverter)
Extracts the element type from a SPIR-V pointer type pointing to storage.
static std::optional< spirv::Scope > getAtomicOpScope(MemRefType type)
Returns the scope to use for atomic operations use for emulating store operations of unsupported inte...
static Value shiftValue(Location loc, Value value, Value offset, Value mask, OpBuilder &builder)
Returns the targetBits-bit value shifted by the given offset, and cast to the type destination type,...
static FailureOr< SmallVector< Value > > extractLoadCoordsForComposite(memref::LoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter)
static Value adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter, spirv::AccessChainOp op, int sourceBits, int targetBits, OpBuilder &builder)
Returns an adjusted spirv::AccessChainOp.
static bool isAllocationSupported(Operation *allocOp, MemRefType type)
Returns true if the allocations of memref type generated from allocOp can be lowered to SPIR-V.
static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits, int targetBits, OpBuilder &builder)
Returns the offset of the value in targetBits representation.
#define ATOMIC_CASE(kind, spirvOp)
static FailureOr< MemoryRequirements > calculateMemoryRequirements(Value accessedPtr, bool isNontemporal, uint64_t preferredAlignment)
Given an accessed SPIR-V pointer, calculates its alignment requirements, if any.
static Value castBoolToIntN(Location loc, Value srcBool, Type dstType, OpBuilder &builder)
Casts the given srcBool into an integer of dstType.
static std::string diag(const llvm::Value &value)
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
unsigned getNumDims() const
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerType getIntegerType(unsigned width)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Type conversion from builtin types to SPIR-V types for shader interface.
bool allows(spirv::Capability capability) const
Checks if the SPIR-V capability inquired is supported.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Include the generated interface declarations.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
void populateMemRefToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating MemRef ops to SPIR-V ops.
spirv::MemoryAccessAttr memoryAccess