28#define DEBUG_TYPE "memref-to-spirv-pattern"
49 assert(targetBits % sourceBits == 0);
51 IntegerAttr idxAttr = builder.
getIntegerAttr(type, targetBits / sourceBits);
52 auto idx = builder.
createOrFold<spirv::ConstantOp>(loc, type, idxAttr);
53 IntegerAttr srcBitsAttr = builder.
getIntegerAttr(type, sourceBits);
55 builder.
createOrFold<spirv::ConstantOp>(loc, type, srcBitsAttr);
56 auto m = builder.
createOrFold<spirv::UModOp>(loc, srcIdx, idx);
57 return builder.
createOrFold<spirv::IMulOp>(loc, type, m, srcBitsValue);
70 spirv::AccessChainOp op,
int sourceBits,
72 assert(targetBits % sourceBits == 0);
73 const auto loc = op.getLoc();
74 Value lastDim = op->getOperand(op.getNumOperands() - 1);
76 IntegerAttr attr = builder.
getIntegerAttr(type, targetBits / sourceBits);
77 auto idx = builder.
createOrFold<spirv::ConstantOp>(loc, type, attr);
78 auto indices = llvm::to_vector<4>(op.getIndices());
82 Type t = typeConverter.convertType(op.getComponentPtr().getType());
83 return spirv::AccessChainOp::create(builder, loc, t, op.getBasePtr(),
93 Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
94 Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
95 return builder.
createOrFold<spirv::SelectOp>(loc, dstType, srcBool, one,
103 IntegerType dstType = cast<IntegerType>(mask.
getType());
104 int targetBits =
static_cast<int>(dstType.getWidth());
106 assert(valueBits <= targetBits);
108 if (valueBits == 1) {
111 if (valueBits < targetBits) {
112 value = spirv::UConvertOp::create(
116 value = builder.
createOrFold<spirv::BitwiseAndOp>(loc, value, mask);
125 if (isa<memref::AllocOp, memref::DeallocOp>(allocOp)) {
126 auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
127 if (!sc || sc.getValue() != spirv::StorageClass::Workgroup)
129 }
else if (isa<memref::AllocaOp>(allocOp)) {
130 auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
131 if (!sc || sc.getValue() != spirv::StorageClass::Function)
139 if (!type.hasStaticShape())
142 Type elementType = type.getElementType();
143 if (
auto vecType = dyn_cast<VectorType>(elementType))
144 elementType = vecType.getElementType();
145 if (
auto compType = dyn_cast<ComplexType>(elementType))
146 elementType = compType.getElementType();
154 auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
155 switch (sc.getValue()) {
156 case spirv::StorageClass::StorageBuffer:
157 return spirv::Scope::Device;
158 case spirv::StorageClass::Workgroup:
159 return spirv::Scope::Workgroup;
169static spirv::MemorySemantics
172 case spirv::StorageClass::StorageBuffer:
173 case spirv::StorageClass::Uniform:
174 return spirv::MemorySemantics::UniformMemory;
175 case spirv::StorageClass::Workgroup:
176 return spirv::MemorySemantics::WorkgroupMemory;
177 case spirv::StorageClass::CrossWorkgroup:
178 return spirv::MemorySemantics::CrossWorkgroupMemory;
179 case spirv::StorageClass::AtomicCounter:
180 return spirv::MemorySemantics::AtomicCounterMemory;
181 case spirv::StorageClass::Image:
182 return spirv::MemorySemantics::ImageMemory;
184 return spirv::MemorySemantics::None;
191 auto sc = cast<spirv::StorageClassAttr>(type.getMemorySpace()).getValue();
192 return spirv::MemorySemantics::AcquireRelease |
205 if (typeConverter.
allows(spirv::Capability::Kernel)) {
206 if (
auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
207 return arrayType.getElementType();
211 Type structElemType = cast<spirv::StructType>(pointeeType).getElementType(0);
212 if (
auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
213 return arrayType.getElementType();
214 return cast<spirv::RuntimeArrayType>(structElemType).getElementType();
222 auto one = spirv::ConstantOp::getZero(srcInt.
getType(), loc, builder);
223 return builder.
createOrFold<spirv::INotEqualOp>(loc, srcInt, one);
237class AllocaOpPattern final :
public OpConversionPattern<memref::AllocaOp> {
242 matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
243 ConversionPatternRewriter &rewriter)
const override;
250class AllocOpPattern final :
public OpConversionPattern<memref::AllocOp> {
255 matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
256 ConversionPatternRewriter &rewriter)
const override;
260class AtomicRMWOpPattern final
261 :
public OpConversionPattern<memref::AtomicRMWOp> {
266 matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
267 ConversionPatternRewriter &rewriter)
const override;
272class DeallocOpPattern final :
public OpConversionPattern<memref::DeallocOp> {
277 matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor,
278 ConversionPatternRewriter &rewriter)
const override;
282class IntLoadOpPattern final :
public OpConversionPattern<memref::LoadOp> {
287 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
288 ConversionPatternRewriter &rewriter)
const override;
292class LoadOpPattern final :
public OpConversionPattern<memref::LoadOp> {
297 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
298 ConversionPatternRewriter &rewriter)
const override;
302class ImageLoadOpPattern final :
public OpConversionPattern<memref::LoadOp> {
307 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
308 ConversionPatternRewriter &rewriter)
const override;
312class IntStoreOpPattern final :
public OpConversionPattern<memref::StoreOp> {
317 matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
318 ConversionPatternRewriter &rewriter)
const override;
322class MemorySpaceCastOpPattern final
323 :
public OpConversionPattern<memref::MemorySpaceCastOp> {
328 matchAndRewrite(memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
329 ConversionPatternRewriter &rewriter)
const override;
333class StoreOpPattern final :
public OpConversionPattern<memref::StoreOp> {
338 matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
339 ConversionPatternRewriter &rewriter)
const override;
342class ReinterpretCastPattern final
343 :
public OpConversionPattern<memref::ReinterpretCastOp> {
348 matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
349 ConversionPatternRewriter &rewriter)
const override;
352class CastPattern final :
public OpConversionPattern<memref::CastOp> {
357 matchAndRewrite(memref::CastOp op, OpAdaptor adaptor,
358 ConversionPatternRewriter &rewriter)
const override {
359 Value src = adaptor.getSource();
362 const TypeConverter *converter = getTypeConverter();
363 Type dstType = converter->convertType(op.getType());
364 if (srcType != dstType)
365 return rewriter.notifyMatchFailure(op, [&](Diagnostic &
diag) {
366 diag <<
"types doesn't match: " << srcType <<
" and " << dstType;
369 rewriter.replaceOp(op, src);
375class ExtractAlignedPointerAsIndexOpPattern final
376 :
public OpConversionPattern<memref::ExtractAlignedPointerAsIndexOp> {
381 matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
383 ConversionPatternRewriter &rewriter)
const override;
392AllocaOpPattern::matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
393 ConversionPatternRewriter &rewriter)
const {
394 MemRefType allocType = allocaOp.getType();
396 return rewriter.notifyMatchFailure(allocaOp,
"unhandled allocation type");
399 Type spirvType = getTypeConverter()->convertType(allocType);
401 return rewriter.notifyMatchFailure(allocaOp,
"type conversion failed");
403 rewriter.replaceOpWithNewOp<spirv::VariableOp>(allocaOp, spirvType,
404 spirv::StorageClass::Function,
414AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
415 ConversionPatternRewriter &rewriter)
const {
416 MemRefType allocType = operation.getType();
418 return rewriter.notifyMatchFailure(operation,
"unhandled allocation type");
421 Type spirvType = getTypeConverter()->convertType(allocType);
423 return rewriter.notifyMatchFailure(operation,
"type conversion failed");
430 Location loc = operation.getLoc();
431 spirv::GlobalVariableOp varOp;
433 OpBuilder::InsertionGuard guard(rewriter);
435 rewriter.setInsertionPointToStart(&entryBlock);
436 auto varOps = entryBlock.
getOps<spirv::GlobalVariableOp>();
437 std::string varName =
438 std::string(
"__workgroup_mem__") +
439 std::to_string(std::distance(varOps.begin(), varOps.end()));
440 varOp = spirv::GlobalVariableOp::create(rewriter, loc, spirvType, varName,
445 rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(operation, varOp);
454AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
456 ConversionPatternRewriter &rewriter)
const {
457 auto memrefType = cast<MemRefType>(atomicOp.getMemref().getType());
460 return rewriter.notifyMatchFailure(atomicOp,
461 "unsupported memref memory space");
463 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
464 Type resultType = typeConverter.convertType(atomicOp.getType());
466 return rewriter.notifyMatchFailure(atomicOp,
467 "failed to convert result type");
469 auto loc = atomicOp.getLoc();
472 adaptor.getIndices(), loc, rewriter);
480 int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
481 auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
483 return rewriter.notifyMatchFailure(atomicOp,
484 "failed to convert memref type");
486 Type pointeeType = pointerType.getPointeeType();
487 Type storageElemType =
489 if (!storageElemType || !storageElemType.
isIntOrFloat())
490 return rewriter.notifyMatchFailure(
491 atomicOp,
"failed to determine destination element type");
494 assert(dstBits % srcBits == 0);
500 if (srcBits == dstBits) {
501#define ATOMIC_CASE(kind, spirvOp) \
502 case arith::AtomicRMWKind::kind: \
503 rewriter.replaceOpWithNewOp<spirv::spirvOp>( \
504 atomicOp, resultType, ptr, *scope, memSem, adaptor.getValue()); \
507 switch (atomicOp.getKind()) {
517 return rewriter.notifyMatchFailure(atomicOp,
"unimplemented atomic kind");
532 if (atomicOp.getKind() != arith::AtomicRMWKind::ori &&
533 atomicOp.getKind() != arith::AtomicRMWKind::andi) {
534 return rewriter.notifyMatchFailure(
536 "atomic op on sub-element-width types is only supported for ori/andi");
541 if (typeConverter.allows(spirv::Capability::Kernel))
542 return rewriter.notifyMatchFailure(
544 "sub-element-width atomic ops unsupported with Kernel capability");
546 auto dstType = cast<IntegerType>(storageElemType);
548 auto accessChainOp = ptr.
getDefiningOp<spirv::AccessChainOp>();
554 assert(accessChainOp.getIndices().size() == 2);
555 Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
558 srcBits, dstBits, rewriter);
560 switch (atomicOp.getKind()) {
561 case arith::AtomicRMWKind::ori: {
564 Value elemMask = rewriter.createOrFold<spirv::ConstantOp>(
565 loc, dstType, rewriter.getIntegerAttr(dstType, (1uLL << srcBits) - 1));
567 shiftValue(loc, adaptor.getValue(), offset, elemMask, rewriter);
568 result = spirv::AtomicOrOp::create(rewriter, loc, dstType, adjustedPtr,
569 *scope, memSem, storeVal);
572 case arith::AtomicRMWKind::andi: {
576 Value elemMask = rewriter.createOrFold<spirv::ConstantOp>(
577 loc, dstType, rewriter.getIntegerAttr(dstType, (1uLL << srcBits) - 1));
579 shiftValue(loc, adaptor.getValue(), offset, elemMask, rewriter);
580 Value shiftedElemMask = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(
581 loc, dstType, elemMask, offset);
582 Value invertedElemMask =
583 rewriter.createOrFold<spirv::NotOp>(loc, dstType, shiftedElemMask);
584 Value mask = rewriter.createOrFold<spirv::BitwiseOrOp>(loc, storeVal,
586 result = spirv::AtomicAndOp::create(rewriter, loc, dstType, adjustedPtr,
587 *scope, memSem, mask);
591 return rewriter.notifyMatchFailure(atomicOp,
"unimplemented atomic kind");
596 result = rewriter.createOrFold<spirv::ShiftRightLogicalOp>(loc, dstType,
598 Value mask = rewriter.createOrFold<spirv::ConstantOp>(
599 loc, dstType, rewriter.getIntegerAttr(dstType, (1uLL << srcBits) - 1));
601 rewriter.createOrFold<spirv::BitwiseAndOp>(loc, dstType,
result, mask);
602 rewriter.replaceOp(atomicOp,
result);
612DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
614 ConversionPatternRewriter &rewriter)
const {
615 MemRefType deallocType = cast<MemRefType>(operation.getMemref().getType());
617 return rewriter.notifyMatchFailure(operation,
"unhandled allocation type");
618 rewriter.eraseOp(operation);
633static FailureOr<MemoryRequirements>
635 uint64_t preferredAlignment) {
636 if (preferredAlignment >= std::numeric_limits<uint32_t>::max()) {
642 auto memoryAccess = spirv::MemoryAccess::None;
644 memoryAccess = spirv::MemoryAccess::Nontemporal;
647 auto ptrType = cast<spirv::PointerType>(accessedPtr.
getType());
648 bool mayOmitAlignment =
649 !preferredAlignment &&
650 ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer;
651 if (mayOmitAlignment) {
652 if (memoryAccess == spirv::MemoryAccess::None) {
661 std::optional<int64_t> sizeInBytes;
662 Type rawPointeeType = ptrType.getPointeeType();
663 if (
auto scalarType = dyn_cast<spirv::ScalarType>(rawPointeeType)) {
665 sizeInBytes = scalarType.getSizeInBytes();
666 }
else if (
auto vecType = dyn_cast<VectorType>(rawPointeeType)) {
669 if (
auto scalarElem =
670 dyn_cast<spirv::ScalarType>(vecType.getElementType())) {
671 if (
auto elemSize = scalarElem.getSizeInBytes())
672 sizeInBytes = *elemSize * vecType.getNumElements();
676 if (!sizeInBytes.has_value())
679 memoryAccess |= spirv::MemoryAccess::Aligned;
680 auto memAccessAttr = spirv::MemoryAccessAttr::get(ctx, memoryAccess);
681 auto alignmentValue = preferredAlignment ? preferredAlignment : *sizeInBytes;
682 auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), alignmentValue);
689template <
class LoadOrStoreOp>
690static FailureOr<MemoryRequirements>
693 llvm::is_one_of<LoadOrStoreOp, memref::LoadOp, memref::StoreOp>::value,
694 "Must be called on either memref::LoadOp or memref::StoreOp");
697 loadOrStoreOp.getNontemporal(),
698 loadOrStoreOp.getAlignment().value_or(0));
702IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
703 ConversionPatternRewriter &rewriter)
const {
704 auto loc = loadOp.getLoc();
705 auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
706 if (!memrefType.getElementType().isSignlessInteger())
709 auto memorySpaceAttr =
710 dyn_cast_if_present<spirv::StorageClassAttr>(memrefType.getMemorySpace());
711 if (!memorySpaceAttr)
712 return rewriter.notifyMatchFailure(
713 loadOp,
"missing memory space SPIR-V storage class attribute");
715 if (memorySpaceAttr.getValue() == spirv::StorageClass::Image)
716 return rewriter.notifyMatchFailure(
718 "failed to lower memref in image storage class to storage buffer");
720 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
723 adaptor.getIndices(), loc, rewriter);
728 int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
729 bool isBool = srcBits == 1;
731 srcBits = typeConverter.getOptions().boolNumBits;
733 auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
735 return rewriter.notifyMatchFailure(loadOp,
"failed to convert memref type");
737 Type pointeeType = pointerType.getPointeeType();
740 assert(dstBits % srcBits == 0);
744 if (srcBits == dstBits) {
746 if (
failed(memoryRequirements))
747 return rewriter.notifyMatchFailure(
748 loadOp,
"failed to determine memory requirements");
750 auto [memoryAccess, alignment] = *memoryRequirements;
751 Value loadVal = spirv::LoadOp::create(rewriter, loc, accessChain,
752 memoryAccess, alignment);
755 rewriter.replaceOp(loadOp, loadVal);
761 if (typeConverter.allows(spirv::Capability::Kernel))
764 auto accessChainOp = accessChain.
getDefiningOp<spirv::AccessChainOp>();
771 assert(accessChainOp.getIndices().size() == 2);
773 srcBits, dstBits, rewriter);
775 if (
failed(memoryRequirements))
776 return rewriter.notifyMatchFailure(
777 loadOp,
"failed to determine memory requirements");
779 auto [memoryAccess, alignment] = *memoryRequirements;
780 Value spvLoadOp = spirv::LoadOp::create(rewriter, loc, dstType, adjustedPtr,
781 memoryAccess, alignment);
785 Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
787 Value
result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
788 loc, spvLoadOp.
getType(), spvLoadOp, offset);
791 Value mask = rewriter.createOrFold<spirv::ConstantOp>(
792 loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
794 rewriter.createOrFold<spirv::BitwiseAndOp>(loc, dstType,
result, mask);
799 IntegerAttr shiftValueAttr =
800 rewriter.getIntegerAttr(dstType, dstBits - srcBits);
802 rewriter.createOrFold<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
803 result = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(loc, dstType,
805 result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
808 rewriter.replaceOp(loadOp,
result);
810 assert(accessChainOp.use_empty());
811 rewriter.eraseOp(accessChainOp);
817LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
818 ConversionPatternRewriter &rewriter)
const {
819 auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
820 if (memrefType.getElementType().isSignlessInteger())
823 auto memorySpaceAttr =
824 dyn_cast_if_present<spirv::StorageClassAttr>(memrefType.getMemorySpace());
825 if (!memorySpaceAttr)
826 return rewriter.notifyMatchFailure(
827 loadOp,
"missing memory space SPIR-V storage class attribute");
829 if (memorySpaceAttr.getValue() == spirv::StorageClass::Image)
830 return rewriter.notifyMatchFailure(
832 "failed to lower memref in image storage class to storage buffer");
835 *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
836 adaptor.getIndices(), loadOp.getLoc(), rewriter);
842 if (
failed(memoryRequirements))
843 return rewriter.notifyMatchFailure(
844 loadOp,
"failed to determine memory requirements");
846 auto [memoryAccess, alignment] = *memoryRequirements;
847 rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr, memoryAccess,
852template <
typename OpAdaptor>
853static FailureOr<SmallVector<Value>>
855 ConversionPatternRewriter &rewriter) {
862 AffineMap map = loadOp.getMemRefType().getLayout().getAffineMap();
864 return rewriter.notifyMatchFailure(
866 "Cannot lower memrefs with memory layout which is not a permutation");
872 for (
unsigned dim = 0; dim < dimCount; ++dim)
878 return llvm::to_vector(llvm::reverse(coords));
882ImageLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
883 ConversionPatternRewriter &rewriter)
const {
884 auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
886 auto memorySpaceAttr =
887 dyn_cast_if_present<spirv::StorageClassAttr>(memrefType.getMemorySpace());
888 if (!memorySpaceAttr)
889 return rewriter.notifyMatchFailure(
890 loadOp,
"missing memory space SPIR-V storage class attribute");
892 if (memorySpaceAttr.getValue() != spirv::StorageClass::Image)
893 return rewriter.notifyMatchFailure(
894 loadOp,
"failed to lower memref in non-image storage class to image");
896 Value loadPtr = adaptor.getMemref();
898 if (
failed(memoryRequirements))
899 return rewriter.notifyMatchFailure(
900 loadOp,
"failed to determine memory requirements");
902 const auto [memoryAccess, alignment] = *memoryRequirements;
904 if (!loadOp.getMemRefType().hasRank())
905 return rewriter.notifyMatchFailure(
906 loadOp,
"cannot lower unranked memrefs to SPIR-V images");
911 if (!isa<spirv::ScalarType>(loadOp.getMemRefType().getElementType()))
912 return rewriter.notifyMatchFailure(
914 "cannot lower memrefs who's element type is not a SPIR-V scalar type"
921 auto convertedPointeeType = cast<spirv::PointerType>(
922 getTypeConverter()->convertType(loadOp.getMemRefType()));
923 if (!isa<spirv::SampledImageType>(convertedPointeeType.getPointeeType()))
924 return rewriter.notifyMatchFailure(loadOp,
925 "cannot lower memrefs which do not "
926 "convert to SPIR-V sampled images");
929 Location loc = loadOp->getLoc();
931 spirv::LoadOp::create(rewriter, loc, loadPtr, memoryAccess, alignment);
933 auto imageOp = spirv::ImageOp::create(rewriter, loc, imageLoadOp);
937 if (memrefType.getRank() == 1) {
938 coords = adaptor.getIndices()[0];
940 FailureOr<SmallVector<Value>> maybeCoords =
944 auto coordVectorType = VectorType::get({loadOp.getMemRefType().getRank()},
945 adaptor.getIndices().
getType()[0]);
946 coords = spirv::CompositeConstructOp::create(rewriter, loc, coordVectorType,
947 maybeCoords.value());
951 auto resultVectorType = VectorType::get({4}, loadOp.getType());
952 auto fetchOp = spirv::ImageFetchOp::create(
953 rewriter, loc, resultVectorType, imageOp, coords,
954 mlir::spirv::ImageOperandsAttr{},
ValueRange{});
959 auto compositeExtractOp =
960 spirv::CompositeExtractOp::create(rewriter, loc, fetchOp, 0);
962 rewriter.replaceOp(loadOp, compositeExtractOp);
967IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
968 ConversionPatternRewriter &rewriter)
const {
969 auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
970 if (!memrefType.getElementType().isSignlessInteger())
971 return rewriter.notifyMatchFailure(storeOp,
972 "element type is not a signless int");
974 auto loc = storeOp.getLoc();
975 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
978 adaptor.getIndices(), loc, rewriter);
981 return rewriter.notifyMatchFailure(
982 storeOp,
"failed to convert element pointer type");
984 int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
986 bool isBool = srcBits == 1;
988 srcBits = typeConverter.getOptions().boolNumBits;
990 auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
992 return rewriter.notifyMatchFailure(storeOp,
993 "failed to convert memref type");
995 Type pointeeType = pointerType.getPointeeType();
996 auto dstType = dyn_cast<IntegerType>(
999 return rewriter.notifyMatchFailure(
1000 storeOp,
"failed to determine destination element type");
1002 int dstBits =
static_cast<int>(dstType.getWidth());
1003 assert(dstBits % srcBits == 0);
1005 if (srcBits == dstBits) {
1007 if (
failed(memoryRequirements))
1008 return rewriter.notifyMatchFailure(
1009 storeOp,
"failed to determine memory requirements");
1011 auto [memoryAccess, alignment] = *memoryRequirements;
1012 Value storeVal = adaptor.getValue();
1015 rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal,
1016 memoryAccess, alignment);
1022 if (typeConverter.allows(spirv::Capability::Kernel))
1025 auto accessChainOp = accessChain.
getDefiningOp<spirv::AccessChainOp>();
1040 assert(accessChainOp.getIndices().size() == 2);
1041 Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
1046 Value mask = rewriter.createOrFold<spirv::ConstantOp>(
1047 loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
1048 Value clearBitsMask = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(
1049 loc, dstType, mask, offset);
1051 rewriter.createOrFold<spirv::NotOp>(loc, dstType, clearBitsMask);
1053 Value storeVal =
shiftValue(loc, adaptor.getValue(), offset, mask, rewriter);
1055 srcBits, dstBits, rewriter);
1058 return rewriter.notifyMatchFailure(storeOp,
"atomic scope not available");
1061 Value
result = spirv::AtomicAndOp::create(rewriter, loc, dstType, adjustedPtr,
1062 *scope, memSem, clearBitsMask);
1063 result = spirv::AtomicOrOp::create(rewriter, loc, dstType, adjustedPtr,
1064 *scope, memSem, storeVal);
1070 rewriter.eraseOp(storeOp);
1072 assert(accessChainOp.use_empty());
1073 rewriter.eraseOp(accessChainOp);
1082LogicalResult MemorySpaceCastOpPattern::matchAndRewrite(
1083 memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
1084 ConversionPatternRewriter &rewriter)
const {
1085 Location loc = addrCastOp.getLoc();
1086 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
1087 if (!typeConverter.allows(spirv::Capability::Kernel))
1088 return rewriter.notifyMatchFailure(
1089 loc,
"address space casts require kernel capability");
1091 auto sourceType = dyn_cast<MemRefType>(addrCastOp.getSource().getType());
1093 return rewriter.notifyMatchFailure(
1094 loc,
"SPIR-V lowering requires ranked memref types");
1095 auto resultType = cast<MemRefType>(addrCastOp.getResult().getType());
1097 auto sourceStorageClassAttr =
1098 dyn_cast_or_null<spirv::StorageClassAttr>(sourceType.getMemorySpace());
1099 if (!sourceStorageClassAttr)
1100 return rewriter.notifyMatchFailure(loc, [sourceType](Diagnostic &
diag) {
1101 diag <<
"source address space " << sourceType.getMemorySpace()
1102 <<
" must be a SPIR-V storage class";
1104 auto resultStorageClassAttr =
1105 dyn_cast_or_null<spirv::StorageClassAttr>(resultType.getMemorySpace());
1106 if (!resultStorageClassAttr)
1107 return rewriter.notifyMatchFailure(loc, [resultType](Diagnostic &
diag) {
1108 diag <<
"result address space " << resultType.getMemorySpace()
1109 <<
" must be a SPIR-V storage class";
1112 spirv::StorageClass sourceSc = sourceStorageClassAttr.getValue();
1113 spirv::StorageClass resultSc = resultStorageClassAttr.getValue();
1115 Value
result = adaptor.getSource();
1116 Type resultPtrType = typeConverter.convertType(resultType);
1118 return rewriter.notifyMatchFailure(addrCastOp,
1119 "failed to convert memref type");
1121 Type genericPtrType = resultPtrType;
1129 if (sourceSc != spirv::StorageClass::Generic &&
1130 resultSc != spirv::StorageClass::Generic) {
1131 Type intermediateType =
1132 MemRefType::get(sourceType.getShape(), sourceType.getElementType(),
1133 sourceType.getLayout(),
1134 rewriter.getAttr<spirv::StorageClassAttr>(
1135 spirv::StorageClass::Generic));
1136 genericPtrType = typeConverter.convertType(intermediateType);
1138 if (sourceSc != spirv::StorageClass::Generic) {
1139 result = spirv::PtrCastToGenericOp::create(rewriter, loc, genericPtrType,
1142 if (resultSc != spirv::StorageClass::Generic) {
1144 spirv::GenericCastToPtrOp::create(rewriter, loc, resultPtrType,
result);
1146 rewriter.replaceOp(addrCastOp,
result);
1151StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
1152 ConversionPatternRewriter &rewriter)
const {
1153 auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
1154 if (memrefType.getElementType().isSignlessInteger())
1155 return rewriter.notifyMatchFailure(storeOp,
"signless int");
1157 *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
1158 adaptor.getIndices(), storeOp.getLoc(), rewriter);
1161 return rewriter.notifyMatchFailure(storeOp,
"type conversion failed");
1164 if (
failed(memoryRequirements))
1165 return rewriter.notifyMatchFailure(
1166 storeOp,
"failed to determine memory requirements");
1168 auto [memoryAccess, alignment] = *memoryRequirements;
1169 rewriter.replaceOpWithNewOp<spirv::StoreOp>(
1170 storeOp, storePtr, adaptor.getValue(), memoryAccess, alignment);
1174LogicalResult ReinterpretCastPattern::matchAndRewrite(
1175 memref::ReinterpretCastOp op, OpAdaptor adaptor,
1176 ConversionPatternRewriter &rewriter)
const {
1177 Value src = adaptor.getSource();
1178 auto srcType = dyn_cast<spirv::PointerType>(src.
getType());
1181 return rewriter.notifyMatchFailure(op, [&](Diagnostic &
diag) {
1185 const TypeConverter *converter = getTypeConverter();
1187 auto dstType = converter->convertType<spirv::PointerType>(op.getType());
1188 if (dstType != srcType)
1189 return rewriter.notifyMatchFailure(op, [&](Diagnostic &
diag) {
1190 diag <<
"invalid dst type " << op.getType();
1193 OpFoldResult offset =
1194 getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter)
1197 rewriter.replaceOp(op, src);
1201 Type intType = converter->convertType(rewriter.getIndexType());
1203 return rewriter.notifyMatchFailure(op,
"failed to convert index type");
1205 Location loc = op.getLoc();
1206 auto offsetValue = [&]() -> Value {
1207 if (
auto val = dyn_cast<Value>(offset))
1210 int64_t attrVal = cast<IntegerAttr>(cast<Attribute>(offset)).getInt();
1211 Attribute attr = rewriter.getIntegerAttr(intType, attrVal);
1212 return rewriter.createOrFold<spirv::ConstantOp>(loc, intType, attr);
1215 rewriter.replaceOpWithNewOp<spirv::InBoundsPtrAccessChainOp>(
1224LogicalResult ExtractAlignedPointerAsIndexOpPattern::matchAndRewrite(
1225 memref::ExtractAlignedPointerAsIndexOp extractOp, OpAdaptor adaptor,
1226 ConversionPatternRewriter &rewriter)
const {
1227 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
1228 Type indexType = typeConverter.getIndexType();
1229 rewriter.replaceOpWithNewOp<spirv::ConvertPtrToUOp>(extractOp, indexType,
1230 adaptor.getSource());
1241 patterns.
add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
1242 DeallocOpPattern, IntLoadOpPattern, ImageLoadOpPattern,
1243 IntStoreOpPattern, LoadOpPattern, MemorySpaceCastOpPattern,
1244 StoreOpPattern, ReinterpretCastPattern, CastPattern,
1245 ExtractAlignedPointerAsIndexOpPattern>(typeConverter,
static spirv::MemorySemantics getMemorySemanticsForStorageClass(spirv::StorageClass sc)
Returns the MemorySemantics storage-class bit corresponding to sc.
static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder)
Casts the given srcInt into a boolean value.
static Type getElementTypeForStoragePointer(Type pointeeType, const SPIRVTypeConverter &typeConverter)
Extracts the element type from a SPIR-V pointer type pointing to storage.
static std::optional< spirv::Scope > getAtomicOpScope(MemRefType type)
Returns the scope to use for atomic operations use for emulating store operations of unsupported inte...
static Value shiftValue(Location loc, Value value, Value offset, Value mask, OpBuilder &builder)
Returns the targetBits-bit value shifted by the given offset, and cast to the type destination type,...
static FailureOr< SmallVector< Value > > extractLoadCoordsForComposite(memref::LoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter)
static Value adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter, spirv::AccessChainOp op, int sourceBits, int targetBits, OpBuilder &builder)
Returns an adjusted spirv::AccessChainOp.
static bool isAllocationSupported(Operation *allocOp, MemRefType type)
Returns true if the allocations of memref type generated from allocOp can be lowered to SPIR-V.
static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits, int targetBits, OpBuilder &builder)
Returns the offset of the value in targetBits representation.
static spirv::MemorySemantics getAtomicAcqRelMemorySemantics(MemRefType type)
Returns the AcquireRelease memory semantics OR'd with the storage-class bit derived from the memory s...
#define ATOMIC_CASE(kind, spirvOp)
static FailureOr< MemoryRequirements > calculateMemoryRequirements(Value accessedPtr, bool isNontemporal, uint64_t preferredAlignment)
Given an accessed SPIR-V pointer, calculates its alignment requirements, if any.
static Value castBoolToIntN(Location loc, Value srcBool, Type dstType, OpBuilder &builder)
Casts the given srcBool into an integer of dstType.
static std::string diag(const llvm::Value &value)
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
unsigned getNumDims() const
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerType getIntegerType(unsigned width)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Type conversion from builtin types to SPIR-V types for shader interface.
bool allows(spirv::Capability capability) const
Checks if the SPIR-V capability inquired is supported.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Include the generated interface declarations.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
void populateMemRefToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating MemRef ops to SPIR-V ops.
spirv::MemoryAccessAttr memoryAccess