28#define DEBUG_TYPE "memref-to-spirv-pattern"
49 assert(targetBits % sourceBits == 0);
51 IntegerAttr idxAttr = builder.
getIntegerAttr(type, targetBits / sourceBits);
52 auto idx = builder.
createOrFold<spirv::ConstantOp>(loc, type, idxAttr);
53 IntegerAttr srcBitsAttr = builder.
getIntegerAttr(type, sourceBits);
55 builder.
createOrFold<spirv::ConstantOp>(loc, type, srcBitsAttr);
56 auto m = builder.
createOrFold<spirv::UModOp>(loc, srcIdx, idx);
57 return builder.
createOrFold<spirv::IMulOp>(loc, type, m, srcBitsValue);
70 spirv::AccessChainOp op,
int sourceBits,
72 assert(targetBits % sourceBits == 0);
73 const auto loc = op.getLoc();
74 Value lastDim = op->getOperand(op.getNumOperands() - 1);
76 IntegerAttr attr = builder.
getIntegerAttr(type, targetBits / sourceBits);
77 auto idx = builder.
createOrFold<spirv::ConstantOp>(loc, type, attr);
78 auto indices = llvm::to_vector<4>(op.getIndices());
82 Type t = typeConverter.convertType(op.getComponentPtr().getType());
83 return spirv::AccessChainOp::create(builder, loc, t, op.getBasePtr(),
93 Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
94 Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
95 return builder.
createOrFold<spirv::SelectOp>(loc, dstType, srcBool, one,
103 IntegerType dstType = cast<IntegerType>(mask.
getType());
104 int targetBits =
static_cast<int>(dstType.getWidth());
106 assert(valueBits <= targetBits);
108 if (valueBits == 1) {
111 if (valueBits < targetBits) {
112 value = spirv::UConvertOp::create(
116 value = builder.
createOrFold<spirv::BitwiseAndOp>(loc, value, mask);
125 if (isa<memref::AllocOp, memref::DeallocOp>(allocOp)) {
126 auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
127 if (!sc || sc.getValue() != spirv::StorageClass::Workgroup)
129 }
else if (isa<memref::AllocaOp>(allocOp)) {
130 auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
131 if (!sc || sc.getValue() != spirv::StorageClass::Function)
139 if (!type.hasStaticShape())
142 Type elementType = type.getElementType();
143 if (
auto vecType = dyn_cast<VectorType>(elementType))
144 elementType = vecType.getElementType();
152 auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
153 switch (sc.getValue()) {
154 case spirv::StorageClass::StorageBuffer:
155 return spirv::Scope::Device;
156 case spirv::StorageClass::Workgroup:
157 return spirv::Scope::Workgroup;
169 auto one = spirv::ConstantOp::getZero(srcInt.
getType(), loc, builder);
170 return builder.
createOrFold<spirv::INotEqualOp>(loc, srcInt, one);
184class AllocaOpPattern final :
public OpConversionPattern<memref::AllocaOp> {
189 matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
190 ConversionPatternRewriter &rewriter)
const override;
197class AllocOpPattern final :
public OpConversionPattern<memref::AllocOp> {
202 matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
203 ConversionPatternRewriter &rewriter)
const override;
207class AtomicRMWOpPattern final
208 :
public OpConversionPattern<memref::AtomicRMWOp> {
213 matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
214 ConversionPatternRewriter &rewriter)
const override;
219class DeallocOpPattern final :
public OpConversionPattern<memref::DeallocOp> {
224 matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor,
225 ConversionPatternRewriter &rewriter)
const override;
229class IntLoadOpPattern final :
public OpConversionPattern<memref::LoadOp> {
234 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
235 ConversionPatternRewriter &rewriter)
const override;
239class LoadOpPattern final :
public OpConversionPattern<memref::LoadOp> {
244 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
245 ConversionPatternRewriter &rewriter)
const override;
249class ImageLoadOpPattern final :
public OpConversionPattern<memref::LoadOp> {
254 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
255 ConversionPatternRewriter &rewriter)
const override;
259class IntStoreOpPattern final :
public OpConversionPattern<memref::StoreOp> {
264 matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
265 ConversionPatternRewriter &rewriter)
const override;
269class MemorySpaceCastOpPattern final
270 :
public OpConversionPattern<memref::MemorySpaceCastOp> {
275 matchAndRewrite(memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
276 ConversionPatternRewriter &rewriter)
const override;
280class StoreOpPattern final :
public OpConversionPattern<memref::StoreOp> {
285 matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
286 ConversionPatternRewriter &rewriter)
const override;
289class ReinterpretCastPattern final
290 :
public OpConversionPattern<memref::ReinterpretCastOp> {
295 matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
296 ConversionPatternRewriter &rewriter)
const override;
299class CastPattern final :
public OpConversionPattern<memref::CastOp> {
304 matchAndRewrite(memref::CastOp op, OpAdaptor adaptor,
305 ConversionPatternRewriter &rewriter)
const override {
306 Value src = adaptor.getSource();
309 const TypeConverter *converter = getTypeConverter();
310 Type dstType = converter->convertType(op.getType());
311 if (srcType != dstType)
312 return rewriter.notifyMatchFailure(op, [&](Diagnostic &
diag) {
313 diag <<
"types doesn't match: " << srcType <<
" and " << dstType;
316 rewriter.replaceOp(op, src);
322class ExtractAlignedPointerAsIndexOpPattern final
323 :
public OpConversionPattern<memref::ExtractAlignedPointerAsIndexOp> {
328 matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
330 ConversionPatternRewriter &rewriter)
const override;
339AllocaOpPattern::matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
340 ConversionPatternRewriter &rewriter)
const {
341 MemRefType allocType = allocaOp.getType();
343 return rewriter.notifyMatchFailure(allocaOp,
"unhandled allocation type");
346 Type spirvType = getTypeConverter()->convertType(allocType);
348 return rewriter.notifyMatchFailure(allocaOp,
"type conversion failed");
350 rewriter.replaceOpWithNewOp<spirv::VariableOp>(allocaOp, spirvType,
351 spirv::StorageClass::Function,
361AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
362 ConversionPatternRewriter &rewriter)
const {
363 MemRefType allocType = operation.getType();
365 return rewriter.notifyMatchFailure(operation,
"unhandled allocation type");
368 Type spirvType = getTypeConverter()->convertType(allocType);
370 return rewriter.notifyMatchFailure(operation,
"type conversion failed");
377 Location loc = operation.getLoc();
378 spirv::GlobalVariableOp varOp;
380 OpBuilder::InsertionGuard guard(rewriter);
382 rewriter.setInsertionPointToStart(&entryBlock);
383 auto varOps = entryBlock.
getOps<spirv::GlobalVariableOp>();
384 std::string varName =
385 std::string(
"__workgroup_mem__") +
386 std::to_string(std::distance(varOps.begin(), varOps.end()));
387 varOp = spirv::GlobalVariableOp::create(rewriter, loc, spirvType, varName,
392 rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(operation, varOp);
401AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
403 ConversionPatternRewriter &rewriter)
const {
404 if (isa<FloatType>(atomicOp.getType()))
405 return rewriter.notifyMatchFailure(atomicOp,
406 "unimplemented floating-point case");
408 auto memrefType = cast<MemRefType>(atomicOp.getMemref().getType());
411 return rewriter.notifyMatchFailure(atomicOp,
412 "unsupported memref memory space");
414 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
415 Type resultType = typeConverter.convertType(atomicOp.getType());
417 return rewriter.notifyMatchFailure(atomicOp,
418 "failed to convert result type");
420 auto loc = atomicOp.getLoc();
423 adaptor.getIndices(), loc, rewriter);
428#define ATOMIC_CASE(kind, spirvOp) \
429 case arith::AtomicRMWKind::kind: \
430 rewriter.replaceOpWithNewOp<spirv::spirvOp>( \
431 atomicOp, resultType, ptr, *scope, \
432 spirv::MemorySemantics::AcquireRelease, adaptor.getValue()); \
435 switch (atomicOp.getKind()) {
444 return rewriter.notifyMatchFailure(atomicOp,
"unimplemented atomic kind");
457DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
459 ConversionPatternRewriter &rewriter)
const {
460 MemRefType deallocType = cast<MemRefType>(operation.getMemref().getType());
462 return rewriter.notifyMatchFailure(operation,
"unhandled allocation type");
463 rewriter.eraseOp(operation);
478static FailureOr<MemoryRequirements>
480 uint64_t preferredAlignment) {
481 if (preferredAlignment >= std::numeric_limits<uint32_t>::max()) {
487 auto memoryAccess = spirv::MemoryAccess::None;
489 memoryAccess = spirv::MemoryAccess::Nontemporal;
492 auto ptrType = cast<spirv::PointerType>(accessedPtr.
getType());
493 bool mayOmitAlignment =
494 !preferredAlignment &&
495 ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer;
496 if (mayOmitAlignment) {
497 if (memoryAccess == spirv::MemoryAccess::None) {
506 auto pointeeType = dyn_cast<spirv::ScalarType>(ptrType.getPointeeType());
511 std::optional<int64_t> sizeInBytes = pointeeType.getSizeInBytes();
512 if (!sizeInBytes.has_value())
515 memoryAccess |= spirv::MemoryAccess::Aligned;
516 auto memAccessAttr = spirv::MemoryAccessAttr::get(ctx, memoryAccess);
517 auto alignmentValue = preferredAlignment ? preferredAlignment : *sizeInBytes;
518 auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), alignmentValue);
525template <
class LoadOrStoreOp>
526static FailureOr<MemoryRequirements>
529 llvm::is_one_of<LoadOrStoreOp, memref::LoadOp, memref::StoreOp>::value,
530 "Must be called on either memref::LoadOp or memref::StoreOp");
533 loadOrStoreOp.getNontemporal(),
534 loadOrStoreOp.getAlignment().value_or(0));
538IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
539 ConversionPatternRewriter &rewriter)
const {
540 auto loc = loadOp.getLoc();
541 auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
542 if (!memrefType.getElementType().isSignlessInteger())
545 auto memorySpaceAttr =
546 dyn_cast_if_present<spirv::StorageClassAttr>(memrefType.getMemorySpace());
547 if (!memorySpaceAttr)
548 return rewriter.notifyMatchFailure(
549 loadOp,
"missing memory space SPIR-V storage class attribute");
551 if (memorySpaceAttr.getValue() == spirv::StorageClass::Image)
552 return rewriter.notifyMatchFailure(
554 "failed to lower memref in image storage class to storage buffer");
556 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
559 adaptor.getIndices(), loc, rewriter);
564 int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
565 bool isBool = srcBits == 1;
567 srcBits = typeConverter.getOptions().boolNumBits;
569 auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
571 return rewriter.notifyMatchFailure(loadOp,
"failed to convert memref type");
573 Type pointeeType = pointerType.getPointeeType();
575 if (typeConverter.allows(spirv::Capability::Kernel)) {
576 if (
auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
577 dstType = arrayType.getElementType();
579 dstType = pointeeType;
582 Type structElemType =
583 cast<spirv::StructType>(pointeeType).getElementType(0);
584 if (
auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
585 dstType = arrayType.getElementType();
587 dstType = cast<spirv::RuntimeArrayType>(structElemType).getElementType();
590 assert(dstBits % srcBits == 0);
594 if (srcBits == dstBits) {
596 if (
failed(memoryRequirements))
597 return rewriter.notifyMatchFailure(
598 loadOp,
"failed to determine memory requirements");
600 auto [memoryAccess, alignment] = *memoryRequirements;
601 Value loadVal = spirv::LoadOp::create(rewriter, loc, accessChain,
602 memoryAccess, alignment);
605 rewriter.replaceOp(loadOp, loadVal);
611 if (typeConverter.allows(spirv::Capability::Kernel))
614 auto accessChainOp = accessChain.
getDefiningOp<spirv::AccessChainOp>();
621 assert(accessChainOp.getIndices().size() == 2);
623 srcBits, dstBits, rewriter);
625 if (
failed(memoryRequirements))
626 return rewriter.notifyMatchFailure(
627 loadOp,
"failed to determine memory requirements");
629 auto [memoryAccess, alignment] = *memoryRequirements;
630 Value spvLoadOp = spirv::LoadOp::create(rewriter, loc, dstType, adjustedPtr,
631 memoryAccess, alignment);
635 Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
637 Value
result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
638 loc, spvLoadOp.
getType(), spvLoadOp, offset);
641 Value mask = rewriter.createOrFold<spirv::ConstantOp>(
642 loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
644 rewriter.createOrFold<spirv::BitwiseAndOp>(loc, dstType,
result, mask);
649 IntegerAttr shiftValueAttr =
650 rewriter.getIntegerAttr(dstType, dstBits - srcBits);
652 rewriter.createOrFold<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
653 result = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(loc, dstType,
655 result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
658 rewriter.replaceOp(loadOp,
result);
660 assert(accessChainOp.use_empty());
661 rewriter.eraseOp(accessChainOp);
667LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
668 ConversionPatternRewriter &rewriter)
const {
669 auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
670 if (memrefType.getElementType().isSignlessInteger())
673 auto memorySpaceAttr =
674 dyn_cast_if_present<spirv::StorageClassAttr>(memrefType.getMemorySpace());
675 if (!memorySpaceAttr)
676 return rewriter.notifyMatchFailure(
677 loadOp,
"missing memory space SPIR-V storage class attribute");
679 if (memorySpaceAttr.getValue() == spirv::StorageClass::Image)
680 return rewriter.notifyMatchFailure(
682 "failed to lower memref in image storage class to storage buffer");
685 *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
686 adaptor.getIndices(), loadOp.getLoc(), rewriter);
692 if (
failed(memoryRequirements))
693 return rewriter.notifyMatchFailure(
694 loadOp,
"failed to determine memory requirements");
696 auto [memoryAccess, alignment] = *memoryRequirements;
697 rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr, memoryAccess,
702template <
typename OpAdaptor>
703static FailureOr<SmallVector<Value>>
705 ConversionPatternRewriter &rewriter) {
712 AffineMap map = loadOp.getMemRefType().getLayout().getAffineMap();
714 return rewriter.notifyMatchFailure(
716 "Cannot lower memrefs with memory layout which is not a permutation");
722 for (
unsigned dim = 0; dim < dimCount; ++dim)
728 return llvm::to_vector(llvm::reverse(coords));
732ImageLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
733 ConversionPatternRewriter &rewriter)
const {
734 auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
736 auto memorySpaceAttr =
737 dyn_cast_if_present<spirv::StorageClassAttr>(memrefType.getMemorySpace());
738 if (!memorySpaceAttr)
739 return rewriter.notifyMatchFailure(
740 loadOp,
"missing memory space SPIR-V storage class attribute");
742 if (memorySpaceAttr.getValue() != spirv::StorageClass::Image)
743 return rewriter.notifyMatchFailure(
744 loadOp,
"failed to lower memref in non-image storage class to image");
746 Value loadPtr = adaptor.getMemref();
748 if (
failed(memoryRequirements))
749 return rewriter.notifyMatchFailure(
750 loadOp,
"failed to determine memory requirements");
752 const auto [memoryAccess, alignment] = *memoryRequirements;
754 if (!loadOp.getMemRefType().hasRank())
755 return rewriter.notifyMatchFailure(
756 loadOp,
"cannot lower unranked memrefs to SPIR-V images");
761 if (!isa<spirv::ScalarType>(loadOp.getMemRefType().getElementType()))
762 return rewriter.notifyMatchFailure(
764 "cannot lower memrefs who's element type is not a SPIR-V scalar type"
771 auto convertedPointeeType = cast<spirv::PointerType>(
772 getTypeConverter()->convertType(loadOp.getMemRefType()));
773 if (!isa<spirv::SampledImageType>(convertedPointeeType.getPointeeType()))
774 return rewriter.notifyMatchFailure(loadOp,
775 "cannot lower memrefs which do not "
776 "convert to SPIR-V sampled images");
779 Location loc = loadOp->getLoc();
781 spirv::LoadOp::create(rewriter, loc, loadPtr, memoryAccess, alignment);
783 auto imageOp = spirv::ImageOp::create(rewriter, loc, imageLoadOp);
787 if (memrefType.getRank() == 1) {
788 coords = adaptor.getIndices()[0];
790 FailureOr<SmallVector<Value>> maybeCoords =
794 auto coordVectorType = VectorType::get({loadOp.getMemRefType().getRank()},
795 adaptor.getIndices().
getType()[0]);
796 coords = spirv::CompositeConstructOp::create(rewriter, loc, coordVectorType,
797 maybeCoords.value());
801 auto resultVectorType = VectorType::get({4}, loadOp.getType());
802 auto fetchOp = spirv::ImageFetchOp::create(
803 rewriter, loc, resultVectorType, imageOp, coords,
804 mlir::spirv::ImageOperandsAttr{},
ValueRange{});
809 auto compositeExtractOp =
810 spirv::CompositeExtractOp::create(rewriter, loc, fetchOp, 0);
812 rewriter.replaceOp(loadOp, compositeExtractOp);
817IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
818 ConversionPatternRewriter &rewriter)
const {
819 auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
820 if (!memrefType.getElementType().isSignlessInteger())
821 return rewriter.notifyMatchFailure(storeOp,
822 "element type is not a signless int");
824 auto loc = storeOp.getLoc();
825 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
828 adaptor.getIndices(), loc, rewriter);
831 return rewriter.notifyMatchFailure(
832 storeOp,
"failed to convert element pointer type");
834 int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
836 bool isBool = srcBits == 1;
838 srcBits = typeConverter.getOptions().boolNumBits;
840 auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
842 return rewriter.notifyMatchFailure(storeOp,
843 "failed to convert memref type");
845 Type pointeeType = pointerType.getPointeeType();
847 if (typeConverter.allows(spirv::Capability::Kernel)) {
848 if (
auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
849 dstType = dyn_cast<IntegerType>(arrayType.getElementType());
851 dstType = dyn_cast<IntegerType>(pointeeType);
854 Type structElemType =
855 cast<spirv::StructType>(pointeeType).getElementType(0);
856 if (
auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
857 dstType = dyn_cast<IntegerType>(arrayType.getElementType());
859 dstType = dyn_cast<IntegerType>(
864 return rewriter.notifyMatchFailure(
865 storeOp,
"failed to determine destination element type");
867 int dstBits =
static_cast<int>(dstType.getWidth());
868 assert(dstBits % srcBits == 0);
870 if (srcBits == dstBits) {
872 if (
failed(memoryRequirements))
873 return rewriter.notifyMatchFailure(
874 storeOp,
"failed to determine memory requirements");
876 auto [memoryAccess, alignment] = *memoryRequirements;
877 Value storeVal = adaptor.getValue();
880 rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal,
881 memoryAccess, alignment);
887 if (typeConverter.allows(spirv::Capability::Kernel))
890 auto accessChainOp = accessChain.
getDefiningOp<spirv::AccessChainOp>();
905 assert(accessChainOp.getIndices().size() == 2);
906 Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
911 Value mask = rewriter.createOrFold<spirv::ConstantOp>(
912 loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
913 Value clearBitsMask = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(
914 loc, dstType, mask, offset);
916 rewriter.createOrFold<spirv::NotOp>(loc, dstType, clearBitsMask);
918 Value storeVal =
shiftValue(loc, adaptor.getValue(), offset, mask, rewriter);
920 srcBits, dstBits, rewriter);
923 return rewriter.notifyMatchFailure(storeOp,
"atomic scope not available");
925 Value
result = spirv::AtomicAndOp::create(
926 rewriter, loc, dstType, adjustedPtr, *scope,
927 spirv::MemorySemantics::AcquireRelease, clearBitsMask);
928 result = spirv::AtomicOrOp::create(
929 rewriter, loc, dstType, adjustedPtr, *scope,
930 spirv::MemorySemantics::AcquireRelease, storeVal);
936 rewriter.eraseOp(storeOp);
938 assert(accessChainOp.use_empty());
939 rewriter.eraseOp(accessChainOp);
948LogicalResult MemorySpaceCastOpPattern::matchAndRewrite(
949 memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
950 ConversionPatternRewriter &rewriter)
const {
951 Location loc = addrCastOp.getLoc();
952 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
953 if (!typeConverter.allows(spirv::Capability::Kernel))
954 return rewriter.notifyMatchFailure(
955 loc,
"address space casts require kernel capability");
957 auto sourceType = dyn_cast<MemRefType>(addrCastOp.getSource().getType());
959 return rewriter.notifyMatchFailure(
960 loc,
"SPIR-V lowering requires ranked memref types");
961 auto resultType = cast<MemRefType>(addrCastOp.getResult().getType());
963 auto sourceStorageClassAttr =
964 dyn_cast_or_null<spirv::StorageClassAttr>(sourceType.getMemorySpace());
965 if (!sourceStorageClassAttr)
966 return rewriter.notifyMatchFailure(loc, [sourceType](Diagnostic &
diag) {
967 diag <<
"source address space " << sourceType.getMemorySpace()
968 <<
" must be a SPIR-V storage class";
970 auto resultStorageClassAttr =
971 dyn_cast_or_null<spirv::StorageClassAttr>(resultType.getMemorySpace());
972 if (!resultStorageClassAttr)
973 return rewriter.notifyMatchFailure(loc, [resultType](Diagnostic &
diag) {
974 diag <<
"result address space " << resultType.getMemorySpace()
975 <<
" must be a SPIR-V storage class";
978 spirv::StorageClass sourceSc = sourceStorageClassAttr.getValue();
979 spirv::StorageClass resultSc = resultStorageClassAttr.getValue();
981 Value
result = adaptor.getSource();
982 Type resultPtrType = typeConverter.convertType(resultType);
984 return rewriter.notifyMatchFailure(addrCastOp,
985 "failed to convert memref type");
987 Type genericPtrType = resultPtrType;
995 if (sourceSc != spirv::StorageClass::Generic &&
996 resultSc != spirv::StorageClass::Generic) {
997 Type intermediateType =
998 MemRefType::get(sourceType.getShape(), sourceType.getElementType(),
999 sourceType.getLayout(),
1000 rewriter.getAttr<spirv::StorageClassAttr>(
1001 spirv::StorageClass::Generic));
1002 genericPtrType = typeConverter.convertType(intermediateType);
1004 if (sourceSc != spirv::StorageClass::Generic) {
1005 result = spirv::PtrCastToGenericOp::create(rewriter, loc, genericPtrType,
1008 if (resultSc != spirv::StorageClass::Generic) {
1010 spirv::GenericCastToPtrOp::create(rewriter, loc, resultPtrType,
result);
1012 rewriter.replaceOp(addrCastOp,
result);
1017StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
1018 ConversionPatternRewriter &rewriter)
const {
1019 auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
1020 if (memrefType.getElementType().isSignlessInteger())
1021 return rewriter.notifyMatchFailure(storeOp,
"signless int");
1023 *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
1024 adaptor.getIndices(), storeOp.getLoc(), rewriter);
1027 return rewriter.notifyMatchFailure(storeOp,
"type conversion failed");
1030 if (
failed(memoryRequirements))
1031 return rewriter.notifyMatchFailure(
1032 storeOp,
"failed to determine memory requirements");
1034 auto [memoryAccess, alignment] = *memoryRequirements;
1035 rewriter.replaceOpWithNewOp<spirv::StoreOp>(
1036 storeOp, storePtr, adaptor.getValue(), memoryAccess, alignment);
1040LogicalResult ReinterpretCastPattern::matchAndRewrite(
1041 memref::ReinterpretCastOp op, OpAdaptor adaptor,
1042 ConversionPatternRewriter &rewriter)
const {
1043 Value src = adaptor.getSource();
1044 auto srcType = dyn_cast<spirv::PointerType>(src.
getType());
1047 return rewriter.notifyMatchFailure(op, [&](Diagnostic &
diag) {
1051 const TypeConverter *converter = getTypeConverter();
1053 auto dstType = converter->convertType<spirv::PointerType>(op.getType());
1054 if (dstType != srcType)
1055 return rewriter.notifyMatchFailure(op, [&](Diagnostic &
diag) {
1056 diag <<
"invalid dst type " << op.getType();
1059 OpFoldResult offset =
1060 getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter)
1063 rewriter.replaceOp(op, src);
1067 Type intType = converter->convertType(rewriter.getIndexType());
1069 return rewriter.notifyMatchFailure(op,
"failed to convert index type");
1071 Location loc = op.getLoc();
1072 auto offsetValue = [&]() -> Value {
1073 if (
auto val = dyn_cast<Value>(offset))
1076 int64_t attrVal = cast<IntegerAttr>(cast<Attribute>(offset)).getInt();
1077 Attribute attr = rewriter.getIntegerAttr(intType, attrVal);
1078 return rewriter.createOrFold<spirv::ConstantOp>(loc, intType, attr);
1081 rewriter.replaceOpWithNewOp<spirv::InBoundsPtrAccessChainOp>(
1090LogicalResult ExtractAlignedPointerAsIndexOpPattern::matchAndRewrite(
1091 memref::ExtractAlignedPointerAsIndexOp extractOp, OpAdaptor adaptor,
1092 ConversionPatternRewriter &rewriter)
const {
1093 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
1094 Type indexType = typeConverter.getIndexType();
1095 rewriter.replaceOpWithNewOp<spirv::ConvertPtrToUOp>(extractOp, indexType,
1096 adaptor.getSource());
1107 patterns.add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
1108 DeallocOpPattern, IntLoadOpPattern, ImageLoadOpPattern,
1109 IntStoreOpPattern, LoadOpPattern, MemorySpaceCastOpPattern,
1110 StoreOpPattern, ReinterpretCastPattern, CastPattern,
1111 ExtractAlignedPointerAsIndexOpPattern>(typeConverter,
static Type getElementType(Type type)
Determine the element type of type.
static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder)
Casts the given srcInt into a boolean value.
static std::optional< spirv::Scope > getAtomicOpScope(MemRefType type)
Returns the scope to use for atomic operations use for emulating store operations of unsupported inte...
static Value shiftValue(Location loc, Value value, Value offset, Value mask, OpBuilder &builder)
Returns the targetBits-bit value shifted by the given offset, and cast to the type destination type,...
static FailureOr< SmallVector< Value > > extractLoadCoordsForComposite(memref::LoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter)
static Value adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter, spirv::AccessChainOp op, int sourceBits, int targetBits, OpBuilder &builder)
Returns an adjusted spirv::AccessChainOp.
static bool isAllocationSupported(Operation *allocOp, MemRefType type)
Returns true if the allocations of memref type generated from allocOp can be lowered to SPIR-V.
static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits, int targetBits, OpBuilder &builder)
Returns the offset of the value in targetBits representation.
#define ATOMIC_CASE(kind, spirvOp)
static FailureOr< MemoryRequirements > calculateMemoryRequirements(Value accessedPtr, bool isNontemporal, uint64_t preferredAlignment)
Given an accessed SPIR-V pointer, calculates its alignment requirements, if any.
static Value castBoolToIntN(Location loc, Value srcBool, Type dstType, OpBuilder &builder)
Casts the given srcBool into an integer of dstType.
static std::string diag(const llvm::Value &value)
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
unsigned getNumDims() const
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerType getIntegerType(unsigned width)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Type conversion from builtin types to SPIR-V types for shader interface.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Include the generated interface declarations.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
void populateMemRefToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating MemRef ops to SPIR-V ops.
spirv::MemoryAccessAttr memoryAccess