28#define DEBUG_TYPE "memref-to-spirv-pattern"
49 assert(targetBits % sourceBits == 0);
51 IntegerAttr idxAttr = builder.
getIntegerAttr(type, targetBits / sourceBits);
52 auto idx = builder.
createOrFold<spirv::ConstantOp>(loc, type, idxAttr);
53 IntegerAttr srcBitsAttr = builder.
getIntegerAttr(type, sourceBits);
55 builder.
createOrFold<spirv::ConstantOp>(loc, type, srcBitsAttr);
56 auto m = builder.
createOrFold<spirv::UModOp>(loc, srcIdx, idx);
57 return builder.
createOrFold<spirv::IMulOp>(loc, type, m, srcBitsValue);
70 spirv::AccessChainOp op,
int sourceBits,
72 assert(targetBits % sourceBits == 0);
73 const auto loc = op.getLoc();
74 Value lastDim = op->getOperand(op.getNumOperands() - 1);
76 IntegerAttr attr = builder.
getIntegerAttr(type, targetBits / sourceBits);
77 auto idx = builder.
createOrFold<spirv::ConstantOp>(loc, type, attr);
78 auto indices = llvm::to_vector<4>(op.getIndices());
82 Type t = typeConverter.convertType(op.getComponentPtr().getType());
83 return spirv::AccessChainOp::create(builder, loc, t, op.getBasePtr(),
93 Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
94 Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
95 return builder.
createOrFold<spirv::SelectOp>(loc, dstType, srcBool, one,
103 IntegerType dstType = cast<IntegerType>(mask.
getType());
104 int targetBits =
static_cast<int>(dstType.getWidth());
106 assert(valueBits <= targetBits);
108 if (valueBits == 1) {
111 if (valueBits < targetBits) {
112 value = spirv::UConvertOp::create(
116 value = builder.
createOrFold<spirv::BitwiseAndOp>(loc, value, mask);
125 if (isa<memref::AllocOp, memref::DeallocOp>(allocOp)) {
126 auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
127 if (!sc || sc.getValue() != spirv::StorageClass::Workgroup)
129 }
else if (isa<memref::AllocaOp>(allocOp)) {
130 auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
131 if (!sc || sc.getValue() != spirv::StorageClass::Function)
139 if (!type.hasStaticShape())
142 Type elementType = type.getElementType();
143 if (
auto vecType = dyn_cast<VectorType>(elementType))
144 elementType = vecType.getElementType();
145 if (
auto compType = dyn_cast<ComplexType>(elementType))
146 elementType = compType.getElementType();
154 auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
155 switch (sc.getValue()) {
156 case spirv::StorageClass::StorageBuffer:
157 return spirv::Scope::Device;
158 case spirv::StorageClass::Workgroup:
159 return spirv::Scope::Workgroup;
171 auto one = spirv::ConstantOp::getZero(srcInt.
getType(), loc, builder);
172 return builder.
createOrFold<spirv::INotEqualOp>(loc, srcInt, one);
186class AllocaOpPattern final :
public OpConversionPattern<memref::AllocaOp> {
191 matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
192 ConversionPatternRewriter &rewriter)
const override;
199class AllocOpPattern final :
public OpConversionPattern<memref::AllocOp> {
204 matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
205 ConversionPatternRewriter &rewriter)
const override;
209class AtomicRMWOpPattern final
210 :
public OpConversionPattern<memref::AtomicRMWOp> {
215 matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
216 ConversionPatternRewriter &rewriter)
const override;
221class DeallocOpPattern final :
public OpConversionPattern<memref::DeallocOp> {
226 matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor,
227 ConversionPatternRewriter &rewriter)
const override;
231class IntLoadOpPattern final :
public OpConversionPattern<memref::LoadOp> {
236 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
237 ConversionPatternRewriter &rewriter)
const override;
241class LoadOpPattern final :
public OpConversionPattern<memref::LoadOp> {
246 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
247 ConversionPatternRewriter &rewriter)
const override;
251class ImageLoadOpPattern final :
public OpConversionPattern<memref::LoadOp> {
256 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
257 ConversionPatternRewriter &rewriter)
const override;
261class IntStoreOpPattern final :
public OpConversionPattern<memref::StoreOp> {
266 matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
267 ConversionPatternRewriter &rewriter)
const override;
271class MemorySpaceCastOpPattern final
272 :
public OpConversionPattern<memref::MemorySpaceCastOp> {
277 matchAndRewrite(memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
278 ConversionPatternRewriter &rewriter)
const override;
282class StoreOpPattern final :
public OpConversionPattern<memref::StoreOp> {
287 matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
288 ConversionPatternRewriter &rewriter)
const override;
291class ReinterpretCastPattern final
292 :
public OpConversionPattern<memref::ReinterpretCastOp> {
297 matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
298 ConversionPatternRewriter &rewriter)
const override;
301class CastPattern final :
public OpConversionPattern<memref::CastOp> {
306 matchAndRewrite(memref::CastOp op, OpAdaptor adaptor,
307 ConversionPatternRewriter &rewriter)
const override {
308 Value src = adaptor.getSource();
311 const TypeConverter *converter = getTypeConverter();
312 Type dstType = converter->convertType(op.getType());
313 if (srcType != dstType)
314 return rewriter.notifyMatchFailure(op, [&](Diagnostic &
diag) {
315 diag <<
"types doesn't match: " << srcType <<
" and " << dstType;
318 rewriter.replaceOp(op, src);
324class ExtractAlignedPointerAsIndexOpPattern final
325 :
public OpConversionPattern<memref::ExtractAlignedPointerAsIndexOp> {
330 matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
332 ConversionPatternRewriter &rewriter)
const override;
341AllocaOpPattern::matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
342 ConversionPatternRewriter &rewriter)
const {
343 MemRefType allocType = allocaOp.getType();
345 return rewriter.notifyMatchFailure(allocaOp,
"unhandled allocation type");
348 Type spirvType = getTypeConverter()->convertType(allocType);
350 return rewriter.notifyMatchFailure(allocaOp,
"type conversion failed");
352 rewriter.replaceOpWithNewOp<spirv::VariableOp>(allocaOp, spirvType,
353 spirv::StorageClass::Function,
363AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
364 ConversionPatternRewriter &rewriter)
const {
365 MemRefType allocType = operation.getType();
367 return rewriter.notifyMatchFailure(operation,
"unhandled allocation type");
370 Type spirvType = getTypeConverter()->convertType(allocType);
372 return rewriter.notifyMatchFailure(operation,
"type conversion failed");
379 Location loc = operation.getLoc();
380 spirv::GlobalVariableOp varOp;
382 OpBuilder::InsertionGuard guard(rewriter);
384 rewriter.setInsertionPointToStart(&entryBlock);
385 auto varOps = entryBlock.
getOps<spirv::GlobalVariableOp>();
386 std::string varName =
387 std::string(
"__workgroup_mem__") +
388 std::to_string(std::distance(varOps.begin(), varOps.end()));
389 varOp = spirv::GlobalVariableOp::create(rewriter, loc, spirvType, varName,
394 rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(operation, varOp);
403AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
405 ConversionPatternRewriter &rewriter)
const {
406 if (isa<FloatType>(atomicOp.getType()))
407 return rewriter.notifyMatchFailure(atomicOp,
408 "unimplemented floating-point case");
410 auto memrefType = cast<MemRefType>(atomicOp.getMemref().getType());
413 return rewriter.notifyMatchFailure(atomicOp,
414 "unsupported memref memory space");
416 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
417 Type resultType = typeConverter.convertType(atomicOp.getType());
419 return rewriter.notifyMatchFailure(atomicOp,
420 "failed to convert result type");
422 auto loc = atomicOp.getLoc();
425 adaptor.getIndices(), loc, rewriter);
430#define ATOMIC_CASE(kind, spirvOp) \
431 case arith::AtomicRMWKind::kind: \
432 rewriter.replaceOpWithNewOp<spirv::spirvOp>( \
433 atomicOp, resultType, ptr, *scope, \
434 spirv::MemorySemantics::AcquireRelease, adaptor.getValue()); \
437 switch (atomicOp.getKind()) {
446 return rewriter.notifyMatchFailure(atomicOp,
"unimplemented atomic kind");
459DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
461 ConversionPatternRewriter &rewriter)
const {
462 MemRefType deallocType = cast<MemRefType>(operation.getMemref().getType());
464 return rewriter.notifyMatchFailure(operation,
"unhandled allocation type");
465 rewriter.eraseOp(operation);
480static FailureOr<MemoryRequirements>
482 uint64_t preferredAlignment) {
483 if (preferredAlignment >= std::numeric_limits<uint32_t>::max()) {
489 auto memoryAccess = spirv::MemoryAccess::None;
491 memoryAccess = spirv::MemoryAccess::Nontemporal;
494 auto ptrType = cast<spirv::PointerType>(accessedPtr.
getType());
495 bool mayOmitAlignment =
496 !preferredAlignment &&
497 ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer;
498 if (mayOmitAlignment) {
499 if (memoryAccess == spirv::MemoryAccess::None) {
508 auto pointeeType = dyn_cast<spirv::ScalarType>(ptrType.getPointeeType());
513 std::optional<int64_t> sizeInBytes = pointeeType.getSizeInBytes();
514 if (!sizeInBytes.has_value())
517 memoryAccess |= spirv::MemoryAccess::Aligned;
518 auto memAccessAttr = spirv::MemoryAccessAttr::get(ctx, memoryAccess);
519 auto alignmentValue = preferredAlignment ? preferredAlignment : *sizeInBytes;
520 auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), alignmentValue);
527template <
class LoadOrStoreOp>
528static FailureOr<MemoryRequirements>
531 llvm::is_one_of<LoadOrStoreOp, memref::LoadOp, memref::StoreOp>::value,
532 "Must be called on either memref::LoadOp or memref::StoreOp");
535 loadOrStoreOp.getNontemporal(),
536 loadOrStoreOp.getAlignment().value_or(0));
540IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
541 ConversionPatternRewriter &rewriter)
const {
542 auto loc = loadOp.getLoc();
543 auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
544 if (!memrefType.getElementType().isSignlessInteger())
547 auto memorySpaceAttr =
548 dyn_cast_if_present<spirv::StorageClassAttr>(memrefType.getMemorySpace());
549 if (!memorySpaceAttr)
550 return rewriter.notifyMatchFailure(
551 loadOp,
"missing memory space SPIR-V storage class attribute");
553 if (memorySpaceAttr.getValue() == spirv::StorageClass::Image)
554 return rewriter.notifyMatchFailure(
556 "failed to lower memref in image storage class to storage buffer");
558 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
561 adaptor.getIndices(), loc, rewriter);
566 int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
567 bool isBool = srcBits == 1;
569 srcBits = typeConverter.getOptions().boolNumBits;
571 auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
573 return rewriter.notifyMatchFailure(loadOp,
"failed to convert memref type");
575 Type pointeeType = pointerType.getPointeeType();
577 if (typeConverter.allows(spirv::Capability::Kernel)) {
578 if (
auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
579 dstType = arrayType.getElementType();
581 dstType = pointeeType;
584 Type structElemType =
585 cast<spirv::StructType>(pointeeType).getElementType(0);
586 if (
auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
587 dstType = arrayType.getElementType();
589 dstType = cast<spirv::RuntimeArrayType>(structElemType).getElementType();
592 assert(dstBits % srcBits == 0);
596 if (srcBits == dstBits) {
598 if (
failed(memoryRequirements))
599 return rewriter.notifyMatchFailure(
600 loadOp,
"failed to determine memory requirements");
602 auto [memoryAccess, alignment] = *memoryRequirements;
603 Value loadVal = spirv::LoadOp::create(rewriter, loc, accessChain,
604 memoryAccess, alignment);
607 rewriter.replaceOp(loadOp, loadVal);
613 if (typeConverter.allows(spirv::Capability::Kernel))
616 auto accessChainOp = accessChain.
getDefiningOp<spirv::AccessChainOp>();
623 assert(accessChainOp.getIndices().size() == 2);
625 srcBits, dstBits, rewriter);
627 if (
failed(memoryRequirements))
628 return rewriter.notifyMatchFailure(
629 loadOp,
"failed to determine memory requirements");
631 auto [memoryAccess, alignment] = *memoryRequirements;
632 Value spvLoadOp = spirv::LoadOp::create(rewriter, loc, dstType, adjustedPtr,
633 memoryAccess, alignment);
637 Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
639 Value
result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
640 loc, spvLoadOp.
getType(), spvLoadOp, offset);
643 Value mask = rewriter.createOrFold<spirv::ConstantOp>(
644 loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
646 rewriter.createOrFold<spirv::BitwiseAndOp>(loc, dstType,
result, mask);
651 IntegerAttr shiftValueAttr =
652 rewriter.getIntegerAttr(dstType, dstBits - srcBits);
654 rewriter.createOrFold<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
655 result = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(loc, dstType,
657 result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
660 rewriter.replaceOp(loadOp,
result);
662 assert(accessChainOp.use_empty());
663 rewriter.eraseOp(accessChainOp);
669LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
670 ConversionPatternRewriter &rewriter)
const {
671 auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
672 if (memrefType.getElementType().isSignlessInteger())
675 auto memorySpaceAttr =
676 dyn_cast_if_present<spirv::StorageClassAttr>(memrefType.getMemorySpace());
677 if (!memorySpaceAttr)
678 return rewriter.notifyMatchFailure(
679 loadOp,
"missing memory space SPIR-V storage class attribute");
681 if (memorySpaceAttr.getValue() == spirv::StorageClass::Image)
682 return rewriter.notifyMatchFailure(
684 "failed to lower memref in image storage class to storage buffer");
687 *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
688 adaptor.getIndices(), loadOp.getLoc(), rewriter);
694 if (
failed(memoryRequirements))
695 return rewriter.notifyMatchFailure(
696 loadOp,
"failed to determine memory requirements");
698 auto [memoryAccess, alignment] = *memoryRequirements;
699 rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr, memoryAccess,
704template <
typename OpAdaptor>
705static FailureOr<SmallVector<Value>>
707 ConversionPatternRewriter &rewriter) {
714 AffineMap map = loadOp.getMemRefType().getLayout().getAffineMap();
716 return rewriter.notifyMatchFailure(
718 "Cannot lower memrefs with memory layout which is not a permutation");
724 for (
unsigned dim = 0; dim < dimCount; ++dim)
730 return llvm::to_vector(llvm::reverse(coords));
734ImageLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
735 ConversionPatternRewriter &rewriter)
const {
736 auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
738 auto memorySpaceAttr =
739 dyn_cast_if_present<spirv::StorageClassAttr>(memrefType.getMemorySpace());
740 if (!memorySpaceAttr)
741 return rewriter.notifyMatchFailure(
742 loadOp,
"missing memory space SPIR-V storage class attribute");
744 if (memorySpaceAttr.getValue() != spirv::StorageClass::Image)
745 return rewriter.notifyMatchFailure(
746 loadOp,
"failed to lower memref in non-image storage class to image");
748 Value loadPtr = adaptor.getMemref();
750 if (
failed(memoryRequirements))
751 return rewriter.notifyMatchFailure(
752 loadOp,
"failed to determine memory requirements");
754 const auto [memoryAccess, alignment] = *memoryRequirements;
756 if (!loadOp.getMemRefType().hasRank())
757 return rewriter.notifyMatchFailure(
758 loadOp,
"cannot lower unranked memrefs to SPIR-V images");
763 if (!isa<spirv::ScalarType>(loadOp.getMemRefType().getElementType()))
764 return rewriter.notifyMatchFailure(
766 "cannot lower memrefs who's element type is not a SPIR-V scalar type"
773 auto convertedPointeeType = cast<spirv::PointerType>(
774 getTypeConverter()->convertType(loadOp.getMemRefType()));
775 if (!isa<spirv::SampledImageType>(convertedPointeeType.getPointeeType()))
776 return rewriter.notifyMatchFailure(loadOp,
777 "cannot lower memrefs which do not "
778 "convert to SPIR-V sampled images");
781 Location loc = loadOp->getLoc();
783 spirv::LoadOp::create(rewriter, loc, loadPtr, memoryAccess, alignment);
785 auto imageOp = spirv::ImageOp::create(rewriter, loc, imageLoadOp);
789 if (memrefType.getRank() == 1) {
790 coords = adaptor.getIndices()[0];
792 FailureOr<SmallVector<Value>> maybeCoords =
796 auto coordVectorType = VectorType::get({loadOp.getMemRefType().getRank()},
797 adaptor.getIndices().
getType()[0]);
798 coords = spirv::CompositeConstructOp::create(rewriter, loc, coordVectorType,
799 maybeCoords.value());
803 auto resultVectorType = VectorType::get({4}, loadOp.getType());
804 auto fetchOp = spirv::ImageFetchOp::create(
805 rewriter, loc, resultVectorType, imageOp, coords,
806 mlir::spirv::ImageOperandsAttr{},
ValueRange{});
811 auto compositeExtractOp =
812 spirv::CompositeExtractOp::create(rewriter, loc, fetchOp, 0);
814 rewriter.replaceOp(loadOp, compositeExtractOp);
819IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
820 ConversionPatternRewriter &rewriter)
const {
821 auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
822 if (!memrefType.getElementType().isSignlessInteger())
823 return rewriter.notifyMatchFailure(storeOp,
824 "element type is not a signless int");
826 auto loc = storeOp.getLoc();
827 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
830 adaptor.getIndices(), loc, rewriter);
833 return rewriter.notifyMatchFailure(
834 storeOp,
"failed to convert element pointer type");
836 int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
838 bool isBool = srcBits == 1;
840 srcBits = typeConverter.getOptions().boolNumBits;
842 auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
844 return rewriter.notifyMatchFailure(storeOp,
845 "failed to convert memref type");
847 Type pointeeType = pointerType.getPointeeType();
849 if (typeConverter.allows(spirv::Capability::Kernel)) {
850 if (
auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
851 dstType = dyn_cast<IntegerType>(arrayType.getElementType());
853 dstType = dyn_cast<IntegerType>(pointeeType);
856 Type structElemType =
857 cast<spirv::StructType>(pointeeType).getElementType(0);
858 if (
auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
859 dstType = dyn_cast<IntegerType>(arrayType.getElementType());
861 dstType = dyn_cast<IntegerType>(
866 return rewriter.notifyMatchFailure(
867 storeOp,
"failed to determine destination element type");
869 int dstBits =
static_cast<int>(dstType.getWidth());
870 assert(dstBits % srcBits == 0);
872 if (srcBits == dstBits) {
874 if (
failed(memoryRequirements))
875 return rewriter.notifyMatchFailure(
876 storeOp,
"failed to determine memory requirements");
878 auto [memoryAccess, alignment] = *memoryRequirements;
879 Value storeVal = adaptor.getValue();
882 rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal,
883 memoryAccess, alignment);
889 if (typeConverter.allows(spirv::Capability::Kernel))
892 auto accessChainOp = accessChain.
getDefiningOp<spirv::AccessChainOp>();
907 assert(accessChainOp.getIndices().size() == 2);
908 Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
913 Value mask = rewriter.createOrFold<spirv::ConstantOp>(
914 loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
915 Value clearBitsMask = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(
916 loc, dstType, mask, offset);
918 rewriter.createOrFold<spirv::NotOp>(loc, dstType, clearBitsMask);
920 Value storeVal =
shiftValue(loc, adaptor.getValue(), offset, mask, rewriter);
922 srcBits, dstBits, rewriter);
925 return rewriter.notifyMatchFailure(storeOp,
"atomic scope not available");
927 Value
result = spirv::AtomicAndOp::create(
928 rewriter, loc, dstType, adjustedPtr, *scope,
929 spirv::MemorySemantics::AcquireRelease, clearBitsMask);
930 result = spirv::AtomicOrOp::create(
931 rewriter, loc, dstType, adjustedPtr, *scope,
932 spirv::MemorySemantics::AcquireRelease, storeVal);
938 rewriter.eraseOp(storeOp);
940 assert(accessChainOp.use_empty());
941 rewriter.eraseOp(accessChainOp);
950LogicalResult MemorySpaceCastOpPattern::matchAndRewrite(
951 memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
952 ConversionPatternRewriter &rewriter)
const {
953 Location loc = addrCastOp.getLoc();
954 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
955 if (!typeConverter.allows(spirv::Capability::Kernel))
956 return rewriter.notifyMatchFailure(
957 loc,
"address space casts require kernel capability");
959 auto sourceType = dyn_cast<MemRefType>(addrCastOp.getSource().getType());
961 return rewriter.notifyMatchFailure(
962 loc,
"SPIR-V lowering requires ranked memref types");
963 auto resultType = cast<MemRefType>(addrCastOp.getResult().getType());
965 auto sourceStorageClassAttr =
966 dyn_cast_or_null<spirv::StorageClassAttr>(sourceType.getMemorySpace());
967 if (!sourceStorageClassAttr)
968 return rewriter.notifyMatchFailure(loc, [sourceType](Diagnostic &
diag) {
969 diag <<
"source address space " << sourceType.getMemorySpace()
970 <<
" must be a SPIR-V storage class";
972 auto resultStorageClassAttr =
973 dyn_cast_or_null<spirv::StorageClassAttr>(resultType.getMemorySpace());
974 if (!resultStorageClassAttr)
975 return rewriter.notifyMatchFailure(loc, [resultType](Diagnostic &
diag) {
976 diag <<
"result address space " << resultType.getMemorySpace()
977 <<
" must be a SPIR-V storage class";
980 spirv::StorageClass sourceSc = sourceStorageClassAttr.getValue();
981 spirv::StorageClass resultSc = resultStorageClassAttr.getValue();
983 Value
result = adaptor.getSource();
984 Type resultPtrType = typeConverter.convertType(resultType);
986 return rewriter.notifyMatchFailure(addrCastOp,
987 "failed to convert memref type");
989 Type genericPtrType = resultPtrType;
997 if (sourceSc != spirv::StorageClass::Generic &&
998 resultSc != spirv::StorageClass::Generic) {
999 Type intermediateType =
1000 MemRefType::get(sourceType.getShape(), sourceType.getElementType(),
1001 sourceType.getLayout(),
1002 rewriter.getAttr<spirv::StorageClassAttr>(
1003 spirv::StorageClass::Generic));
1004 genericPtrType = typeConverter.convertType(intermediateType);
1006 if (sourceSc != spirv::StorageClass::Generic) {
1007 result = spirv::PtrCastToGenericOp::create(rewriter, loc, genericPtrType,
1010 if (resultSc != spirv::StorageClass::Generic) {
1012 spirv::GenericCastToPtrOp::create(rewriter, loc, resultPtrType,
result);
1014 rewriter.replaceOp(addrCastOp,
result);
1019StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
1020 ConversionPatternRewriter &rewriter)
const {
1021 auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
1022 if (memrefType.getElementType().isSignlessInteger())
1023 return rewriter.notifyMatchFailure(storeOp,
"signless int");
1025 *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
1026 adaptor.getIndices(), storeOp.getLoc(), rewriter);
1029 return rewriter.notifyMatchFailure(storeOp,
"type conversion failed");
1032 if (
failed(memoryRequirements))
1033 return rewriter.notifyMatchFailure(
1034 storeOp,
"failed to determine memory requirements");
1036 auto [memoryAccess, alignment] = *memoryRequirements;
1037 rewriter.replaceOpWithNewOp<spirv::StoreOp>(
1038 storeOp, storePtr, adaptor.getValue(), memoryAccess, alignment);
1042LogicalResult ReinterpretCastPattern::matchAndRewrite(
1043 memref::ReinterpretCastOp op, OpAdaptor adaptor,
1044 ConversionPatternRewriter &rewriter)
const {
1045 Value src = adaptor.getSource();
1046 auto srcType = dyn_cast<spirv::PointerType>(src.
getType());
1049 return rewriter.notifyMatchFailure(op, [&](Diagnostic &
diag) {
1053 const TypeConverter *converter = getTypeConverter();
1055 auto dstType = converter->convertType<spirv::PointerType>(op.getType());
1056 if (dstType != srcType)
1057 return rewriter.notifyMatchFailure(op, [&](Diagnostic &
diag) {
1058 diag <<
"invalid dst type " << op.getType();
1061 OpFoldResult offset =
1062 getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter)
1065 rewriter.replaceOp(op, src);
1069 Type intType = converter->convertType(rewriter.getIndexType());
1071 return rewriter.notifyMatchFailure(op,
"failed to convert index type");
1073 Location loc = op.getLoc();
1074 auto offsetValue = [&]() -> Value {
1075 if (
auto val = dyn_cast<Value>(offset))
1078 int64_t attrVal = cast<IntegerAttr>(cast<Attribute>(offset)).getInt();
1079 Attribute attr = rewriter.getIntegerAttr(intType, attrVal);
1080 return rewriter.createOrFold<spirv::ConstantOp>(loc, intType, attr);
1083 rewriter.replaceOpWithNewOp<spirv::InBoundsPtrAccessChainOp>(
1092LogicalResult ExtractAlignedPointerAsIndexOpPattern::matchAndRewrite(
1093 memref::ExtractAlignedPointerAsIndexOp extractOp, OpAdaptor adaptor,
1094 ConversionPatternRewriter &rewriter)
const {
1095 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
1096 Type indexType = typeConverter.getIndexType();
1097 rewriter.replaceOpWithNewOp<spirv::ConvertPtrToUOp>(extractOp, indexType,
1098 adaptor.getSource());
1109 patterns.add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
1110 DeallocOpPattern, IntLoadOpPattern, ImageLoadOpPattern,
1111 IntStoreOpPattern, LoadOpPattern, MemorySpaceCastOpPattern,
1112 StoreOpPattern, ReinterpretCastPattern, CastPattern,
1113 ExtractAlignedPointerAsIndexOpPattern>(typeConverter,
static Type getElementType(Type type)
Determine the element type of type.
static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder)
Casts the given srcInt into a boolean value.
static std::optional< spirv::Scope > getAtomicOpScope(MemRefType type)
Returns the scope to use for atomic operations use for emulating store operations of unsupported inte...
static Value shiftValue(Location loc, Value value, Value offset, Value mask, OpBuilder &builder)
Returns the targetBits-bit value shifted by the given offset, and cast to the type destination type,...
static FailureOr< SmallVector< Value > > extractLoadCoordsForComposite(memref::LoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter)
static Value adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter, spirv::AccessChainOp op, int sourceBits, int targetBits, OpBuilder &builder)
Returns an adjusted spirv::AccessChainOp.
static bool isAllocationSupported(Operation *allocOp, MemRefType type)
Returns true if the allocations of memref type generated from allocOp can be lowered to SPIR-V.
static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits, int targetBits, OpBuilder &builder)
Returns the offset of the value in targetBits representation.
#define ATOMIC_CASE(kind, spirvOp)
static FailureOr< MemoryRequirements > calculateMemoryRequirements(Value accessedPtr, bool isNontemporal, uint64_t preferredAlignment)
Given an accessed SPIR-V pointer, calculates its alignment requirements, if any.
static Value castBoolToIntN(Location loc, Value srcBool, Type dstType, OpBuilder &builder)
Casts the given srcBool into an integer of dstType.
static std::string diag(const llvm::Value &value)
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
unsigned getNumDims() const
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerType getIntegerType(unsigned width)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Type conversion from builtin types to SPIR-V types for shader interface.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Include the generated interface declarations.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
void populateMemRefToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating MemRef ops to SPIR-V ops.
spirv::MemoryAccessAttr memoryAccess