28#define DEBUG_TYPE "memref-to-spirv-pattern"
49 assert(targetBits % sourceBits == 0);
51 IntegerAttr idxAttr = builder.
getIntegerAttr(type, targetBits / sourceBits);
52 auto idx = builder.
createOrFold<spirv::ConstantOp>(loc, type, idxAttr);
53 IntegerAttr srcBitsAttr = builder.
getIntegerAttr(type, sourceBits);
55 builder.
createOrFold<spirv::ConstantOp>(loc, type, srcBitsAttr);
56 auto m = builder.
createOrFold<spirv::UModOp>(loc, srcIdx, idx);
57 return builder.
createOrFold<spirv::IMulOp>(loc, type, m, srcBitsValue);
70 spirv::AccessChainOp op,
int sourceBits,
72 assert(targetBits % sourceBits == 0);
73 const auto loc = op.getLoc();
74 Value lastDim = op->getOperand(op.getNumOperands() - 1);
76 IntegerAttr attr = builder.
getIntegerAttr(type, targetBits / sourceBits);
77 auto idx = builder.
createOrFold<spirv::ConstantOp>(loc, type, attr);
78 auto indices = llvm::to_vector<4>(op.getIndices());
82 Type t = typeConverter.convertType(op.getComponentPtr().getType());
83 return spirv::AccessChainOp::create(builder, loc, t, op.getBasePtr(),
93 Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
94 Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
95 return builder.
createOrFold<spirv::SelectOp>(loc, dstType, srcBool, one,
103 IntegerType dstType = cast<IntegerType>(mask.
getType());
104 int targetBits =
static_cast<int>(dstType.getWidth());
106 assert(valueBits <= targetBits);
108 if (valueBits == 1) {
111 if (valueBits < targetBits) {
112 value = spirv::UConvertOp::create(
116 value = builder.
createOrFold<spirv::BitwiseAndOp>(loc, value, mask);
125 if (isa<memref::AllocOp, memref::DeallocOp>(allocOp)) {
126 auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
127 if (!sc || sc.getValue() != spirv::StorageClass::Workgroup)
129 }
else if (isa<memref::AllocaOp>(allocOp)) {
130 auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
131 if (!sc || sc.getValue() != spirv::StorageClass::Function)
139 if (!type.hasStaticShape())
142 Type elementType = type.getElementType();
143 if (
auto vecType = dyn_cast<VectorType>(elementType))
144 elementType = vecType.getElementType();
145 if (
auto compType = dyn_cast<ComplexType>(elementType))
146 elementType = compType.getElementType();
154 auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
155 switch (sc.getValue()) {
156 case spirv::StorageClass::StorageBuffer:
157 return spirv::Scope::Device;
158 case spirv::StorageClass::Workgroup:
159 return spirv::Scope::Workgroup;
175 if (typeConverter.
allows(spirv::Capability::Kernel)) {
176 if (
auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
177 return arrayType.getElementType();
181 Type structElemType = cast<spirv::StructType>(pointeeType).getElementType(0);
182 if (
auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
183 return arrayType.getElementType();
184 return cast<spirv::RuntimeArrayType>(structElemType).getElementType();
192 auto one = spirv::ConstantOp::getZero(srcInt.
getType(), loc, builder);
193 return builder.
createOrFold<spirv::INotEqualOp>(loc, srcInt, one);
207class AllocaOpPattern final :
public OpConversionPattern<memref::AllocaOp> {
212 matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
213 ConversionPatternRewriter &rewriter)
const override;
220class AllocOpPattern final :
public OpConversionPattern<memref::AllocOp> {
225 matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
226 ConversionPatternRewriter &rewriter)
const override;
230class AtomicRMWOpPattern final
231 :
public OpConversionPattern<memref::AtomicRMWOp> {
236 matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
237 ConversionPatternRewriter &rewriter)
const override;
242class DeallocOpPattern final :
public OpConversionPattern<memref::DeallocOp> {
247 matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor,
248 ConversionPatternRewriter &rewriter)
const override;
252class IntLoadOpPattern final :
public OpConversionPattern<memref::LoadOp> {
257 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
258 ConversionPatternRewriter &rewriter)
const override;
262class LoadOpPattern final :
public OpConversionPattern<memref::LoadOp> {
267 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
268 ConversionPatternRewriter &rewriter)
const override;
272class ImageLoadOpPattern final :
public OpConversionPattern<memref::LoadOp> {
277 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
278 ConversionPatternRewriter &rewriter)
const override;
282class IntStoreOpPattern final :
public OpConversionPattern<memref::StoreOp> {
287 matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
288 ConversionPatternRewriter &rewriter)
const override;
292class MemorySpaceCastOpPattern final
293 :
public OpConversionPattern<memref::MemorySpaceCastOp> {
298 matchAndRewrite(memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
299 ConversionPatternRewriter &rewriter)
const override;
303class StoreOpPattern final :
public OpConversionPattern<memref::StoreOp> {
308 matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
309 ConversionPatternRewriter &rewriter)
const override;
312class ReinterpretCastPattern final
313 :
public OpConversionPattern<memref::ReinterpretCastOp> {
318 matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
319 ConversionPatternRewriter &rewriter)
const override;
322class CastPattern final :
public OpConversionPattern<memref::CastOp> {
327 matchAndRewrite(memref::CastOp op, OpAdaptor adaptor,
328 ConversionPatternRewriter &rewriter)
const override {
329 Value src = adaptor.getSource();
332 const TypeConverter *converter = getTypeConverter();
333 Type dstType = converter->convertType(op.getType());
334 if (srcType != dstType)
335 return rewriter.notifyMatchFailure(op, [&](Diagnostic &
diag) {
336 diag <<
"types doesn't match: " << srcType <<
" and " << dstType;
339 rewriter.replaceOp(op, src);
345class ExtractAlignedPointerAsIndexOpPattern final
346 :
public OpConversionPattern<memref::ExtractAlignedPointerAsIndexOp> {
351 matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
353 ConversionPatternRewriter &rewriter)
const override;
362AllocaOpPattern::matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
363 ConversionPatternRewriter &rewriter)
const {
364 MemRefType allocType = allocaOp.getType();
366 return rewriter.notifyMatchFailure(allocaOp,
"unhandled allocation type");
369 Type spirvType = getTypeConverter()->convertType(allocType);
371 return rewriter.notifyMatchFailure(allocaOp,
"type conversion failed");
373 rewriter.replaceOpWithNewOp<spirv::VariableOp>(allocaOp, spirvType,
374 spirv::StorageClass::Function,
384AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
385 ConversionPatternRewriter &rewriter)
const {
386 MemRefType allocType = operation.getType();
388 return rewriter.notifyMatchFailure(operation,
"unhandled allocation type");
391 Type spirvType = getTypeConverter()->convertType(allocType);
393 return rewriter.notifyMatchFailure(operation,
"type conversion failed");
400 Location loc = operation.getLoc();
401 spirv::GlobalVariableOp varOp;
403 OpBuilder::InsertionGuard guard(rewriter);
405 rewriter.setInsertionPointToStart(&entryBlock);
406 auto varOps = entryBlock.
getOps<spirv::GlobalVariableOp>();
407 std::string varName =
408 std::string(
"__workgroup_mem__") +
409 std::to_string(std::distance(varOps.begin(), varOps.end()));
410 varOp = spirv::GlobalVariableOp::create(rewriter, loc, spirvType, varName,
415 rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(operation, varOp);
424AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
426 ConversionPatternRewriter &rewriter)
const {
427 if (isa<FloatType>(atomicOp.getType()))
428 return rewriter.notifyMatchFailure(atomicOp,
429 "unimplemented floating-point case");
431 auto memrefType = cast<MemRefType>(atomicOp.getMemref().getType());
434 return rewriter.notifyMatchFailure(atomicOp,
435 "unsupported memref memory space");
437 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
438 Type resultType = typeConverter.convertType(atomicOp.getType());
440 return rewriter.notifyMatchFailure(atomicOp,
441 "failed to convert result type");
443 auto loc = atomicOp.getLoc();
446 adaptor.getIndices(), loc, rewriter);
454 int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
455 auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
457 return rewriter.notifyMatchFailure(atomicOp,
458 "failed to convert memref type");
460 Type pointeeType = pointerType.getPointeeType();
461 auto dstType = dyn_cast<IntegerType>(
464 return rewriter.notifyMatchFailure(
465 atomicOp,
"failed to determine destination element type");
467 int dstBits =
static_cast<int>(dstType.getWidth());
468 assert(dstBits % srcBits == 0);
472 if (srcBits == dstBits) {
473#define ATOMIC_CASE(kind, spirvOp) \
474 case arith::AtomicRMWKind::kind: \
475 rewriter.replaceOpWithNewOp<spirv::spirvOp>( \
476 atomicOp, resultType, ptr, *scope, \
477 spirv::MemorySemantics::AcquireRelease, adaptor.getValue()); \
480 switch (atomicOp.getKind()) {
489 return rewriter.notifyMatchFailure(atomicOp,
"unimplemented atomic kind");
504 if (atomicOp.getKind() != arith::AtomicRMWKind::ori &&
505 atomicOp.getKind() != arith::AtomicRMWKind::andi) {
506 return rewriter.notifyMatchFailure(
508 "atomic op on sub-element-width types is only supported for ori/andi");
513 if (typeConverter.allows(spirv::Capability::Kernel))
514 return rewriter.notifyMatchFailure(
516 "sub-element-width atomic ops unsupported with Kernel capability");
518 auto accessChainOp = ptr.
getDefiningOp<spirv::AccessChainOp>();
524 assert(accessChainOp.getIndices().size() == 2);
525 Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
528 srcBits, dstBits, rewriter);
530 switch (atomicOp.getKind()) {
531 case arith::AtomicRMWKind::ori: {
534 Value elemMask = rewriter.createOrFold<spirv::ConstantOp>(
535 loc, dstType, rewriter.getIntegerAttr(dstType, (1uLL << srcBits) - 1));
537 shiftValue(loc, adaptor.getValue(), offset, elemMask, rewriter);
538 result = spirv::AtomicOrOp::create(
539 rewriter, loc, dstType, adjustedPtr, *scope,
540 spirv::MemorySemantics::AcquireRelease, storeVal);
543 case arith::AtomicRMWKind::andi: {
547 Value elemMask = rewriter.createOrFold<spirv::ConstantOp>(
548 loc, dstType, rewriter.getIntegerAttr(dstType, (1uLL << srcBits) - 1));
550 shiftValue(loc, adaptor.getValue(), offset, elemMask, rewriter);
551 Value shiftedElemMask = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(
552 loc, dstType, elemMask, offset);
553 Value invertedElemMask =
554 rewriter.createOrFold<spirv::NotOp>(loc, dstType, shiftedElemMask);
555 Value mask = rewriter.createOrFold<spirv::BitwiseOrOp>(loc, storeVal,
557 result = spirv::AtomicAndOp::create(
558 rewriter, loc, dstType, adjustedPtr, *scope,
559 spirv::MemorySemantics::AcquireRelease, mask);
563 return rewriter.notifyMatchFailure(atomicOp,
"unimplemented atomic kind");
568 result = rewriter.createOrFold<spirv::ShiftRightLogicalOp>(loc, dstType,
570 Value mask = rewriter.createOrFold<spirv::ConstantOp>(
571 loc, dstType, rewriter.getIntegerAttr(dstType, (1uLL << srcBits) - 1));
573 rewriter.createOrFold<spirv::BitwiseAndOp>(loc, dstType,
result, mask);
574 rewriter.replaceOp(atomicOp,
result);
584DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
586 ConversionPatternRewriter &rewriter)
const {
587 MemRefType deallocType = cast<MemRefType>(operation.getMemref().getType());
589 return rewriter.notifyMatchFailure(operation,
"unhandled allocation type");
590 rewriter.eraseOp(operation);
605static FailureOr<MemoryRequirements>
607 uint64_t preferredAlignment) {
608 if (preferredAlignment >= std::numeric_limits<uint32_t>::max()) {
614 auto memoryAccess = spirv::MemoryAccess::None;
616 memoryAccess = spirv::MemoryAccess::Nontemporal;
619 auto ptrType = cast<spirv::PointerType>(accessedPtr.
getType());
620 bool mayOmitAlignment =
621 !preferredAlignment &&
622 ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer;
623 if (mayOmitAlignment) {
624 if (memoryAccess == spirv::MemoryAccess::None) {
633 auto pointeeType = dyn_cast<spirv::ScalarType>(ptrType.getPointeeType());
638 std::optional<int64_t> sizeInBytes = pointeeType.getSizeInBytes();
639 if (!sizeInBytes.has_value())
642 memoryAccess |= spirv::MemoryAccess::Aligned;
643 auto memAccessAttr = spirv::MemoryAccessAttr::get(ctx, memoryAccess);
644 auto alignmentValue = preferredAlignment ? preferredAlignment : *sizeInBytes;
645 auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), alignmentValue);
652template <
class LoadOrStoreOp>
653static FailureOr<MemoryRequirements>
656 llvm::is_one_of<LoadOrStoreOp, memref::LoadOp, memref::StoreOp>::value,
657 "Must be called on either memref::LoadOp or memref::StoreOp");
660 loadOrStoreOp.getNontemporal(),
661 loadOrStoreOp.getAlignment().value_or(0));
665IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
666 ConversionPatternRewriter &rewriter)
const {
667 auto loc = loadOp.getLoc();
668 auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
669 if (!memrefType.getElementType().isSignlessInteger())
672 auto memorySpaceAttr =
673 dyn_cast_if_present<spirv::StorageClassAttr>(memrefType.getMemorySpace());
674 if (!memorySpaceAttr)
675 return rewriter.notifyMatchFailure(
676 loadOp,
"missing memory space SPIR-V storage class attribute");
678 if (memorySpaceAttr.getValue() == spirv::StorageClass::Image)
679 return rewriter.notifyMatchFailure(
681 "failed to lower memref in image storage class to storage buffer");
683 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
686 adaptor.getIndices(), loc, rewriter);
691 int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
692 bool isBool = srcBits == 1;
694 srcBits = typeConverter.getOptions().boolNumBits;
696 auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
698 return rewriter.notifyMatchFailure(loadOp,
"failed to convert memref type");
700 Type pointeeType = pointerType.getPointeeType();
703 assert(dstBits % srcBits == 0);
707 if (srcBits == dstBits) {
709 if (
failed(memoryRequirements))
710 return rewriter.notifyMatchFailure(
711 loadOp,
"failed to determine memory requirements");
713 auto [memoryAccess, alignment] = *memoryRequirements;
714 Value loadVal = spirv::LoadOp::create(rewriter, loc, accessChain,
715 memoryAccess, alignment);
718 rewriter.replaceOp(loadOp, loadVal);
724 if (typeConverter.allows(spirv::Capability::Kernel))
727 auto accessChainOp = accessChain.
getDefiningOp<spirv::AccessChainOp>();
734 assert(accessChainOp.getIndices().size() == 2);
736 srcBits, dstBits, rewriter);
738 if (
failed(memoryRequirements))
739 return rewriter.notifyMatchFailure(
740 loadOp,
"failed to determine memory requirements");
742 auto [memoryAccess, alignment] = *memoryRequirements;
743 Value spvLoadOp = spirv::LoadOp::create(rewriter, loc, dstType, adjustedPtr,
744 memoryAccess, alignment);
748 Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
750 Value
result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
751 loc, spvLoadOp.
getType(), spvLoadOp, offset);
754 Value mask = rewriter.createOrFold<spirv::ConstantOp>(
755 loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
757 rewriter.createOrFold<spirv::BitwiseAndOp>(loc, dstType,
result, mask);
762 IntegerAttr shiftValueAttr =
763 rewriter.getIntegerAttr(dstType, dstBits - srcBits);
765 rewriter.createOrFold<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
766 result = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(loc, dstType,
768 result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
771 rewriter.replaceOp(loadOp,
result);
773 assert(accessChainOp.use_empty());
774 rewriter.eraseOp(accessChainOp);
780LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
781 ConversionPatternRewriter &rewriter)
const {
782 auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
783 if (memrefType.getElementType().isSignlessInteger())
786 auto memorySpaceAttr =
787 dyn_cast_if_present<spirv::StorageClassAttr>(memrefType.getMemorySpace());
788 if (!memorySpaceAttr)
789 return rewriter.notifyMatchFailure(
790 loadOp,
"missing memory space SPIR-V storage class attribute");
792 if (memorySpaceAttr.getValue() == spirv::StorageClass::Image)
793 return rewriter.notifyMatchFailure(
795 "failed to lower memref in image storage class to storage buffer");
798 *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
799 adaptor.getIndices(), loadOp.getLoc(), rewriter);
805 if (
failed(memoryRequirements))
806 return rewriter.notifyMatchFailure(
807 loadOp,
"failed to determine memory requirements");
809 auto [memoryAccess, alignment] = *memoryRequirements;
810 rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr, memoryAccess,
815template <
typename OpAdaptor>
816static FailureOr<SmallVector<Value>>
818 ConversionPatternRewriter &rewriter) {
825 AffineMap map = loadOp.getMemRefType().getLayout().getAffineMap();
827 return rewriter.notifyMatchFailure(
829 "Cannot lower memrefs with memory layout which is not a permutation");
835 for (
unsigned dim = 0; dim < dimCount; ++dim)
841 return llvm::to_vector(llvm::reverse(coords));
845ImageLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
846 ConversionPatternRewriter &rewriter)
const {
847 auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
849 auto memorySpaceAttr =
850 dyn_cast_if_present<spirv::StorageClassAttr>(memrefType.getMemorySpace());
851 if (!memorySpaceAttr)
852 return rewriter.notifyMatchFailure(
853 loadOp,
"missing memory space SPIR-V storage class attribute");
855 if (memorySpaceAttr.getValue() != spirv::StorageClass::Image)
856 return rewriter.notifyMatchFailure(
857 loadOp,
"failed to lower memref in non-image storage class to image");
859 Value loadPtr = adaptor.getMemref();
861 if (
failed(memoryRequirements))
862 return rewriter.notifyMatchFailure(
863 loadOp,
"failed to determine memory requirements");
865 const auto [memoryAccess, alignment] = *memoryRequirements;
867 if (!loadOp.getMemRefType().hasRank())
868 return rewriter.notifyMatchFailure(
869 loadOp,
"cannot lower unranked memrefs to SPIR-V images");
874 if (!isa<spirv::ScalarType>(loadOp.getMemRefType().getElementType()))
875 return rewriter.notifyMatchFailure(
877 "cannot lower memrefs who's element type is not a SPIR-V scalar type"
884 auto convertedPointeeType = cast<spirv::PointerType>(
885 getTypeConverter()->convertType(loadOp.getMemRefType()));
886 if (!isa<spirv::SampledImageType>(convertedPointeeType.getPointeeType()))
887 return rewriter.notifyMatchFailure(loadOp,
888 "cannot lower memrefs which do not "
889 "convert to SPIR-V sampled images");
892 Location loc = loadOp->getLoc();
894 spirv::LoadOp::create(rewriter, loc, loadPtr, memoryAccess, alignment);
896 auto imageOp = spirv::ImageOp::create(rewriter, loc, imageLoadOp);
900 if (memrefType.getRank() == 1) {
901 coords = adaptor.getIndices()[0];
903 FailureOr<SmallVector<Value>> maybeCoords =
907 auto coordVectorType = VectorType::get({loadOp.getMemRefType().getRank()},
908 adaptor.getIndices().
getType()[0]);
909 coords = spirv::CompositeConstructOp::create(rewriter, loc, coordVectorType,
910 maybeCoords.value());
914 auto resultVectorType = VectorType::get({4}, loadOp.getType());
915 auto fetchOp = spirv::ImageFetchOp::create(
916 rewriter, loc, resultVectorType, imageOp, coords,
917 mlir::spirv::ImageOperandsAttr{},
ValueRange{});
922 auto compositeExtractOp =
923 spirv::CompositeExtractOp::create(rewriter, loc, fetchOp, 0);
925 rewriter.replaceOp(loadOp, compositeExtractOp);
930IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
931 ConversionPatternRewriter &rewriter)
const {
932 auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
933 if (!memrefType.getElementType().isSignlessInteger())
934 return rewriter.notifyMatchFailure(storeOp,
935 "element type is not a signless int");
937 auto loc = storeOp.getLoc();
938 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
941 adaptor.getIndices(), loc, rewriter);
944 return rewriter.notifyMatchFailure(
945 storeOp,
"failed to convert element pointer type");
947 int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
949 bool isBool = srcBits == 1;
951 srcBits = typeConverter.getOptions().boolNumBits;
953 auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
955 return rewriter.notifyMatchFailure(storeOp,
956 "failed to convert memref type");
958 Type pointeeType = pointerType.getPointeeType();
959 auto dstType = dyn_cast<IntegerType>(
962 return rewriter.notifyMatchFailure(
963 storeOp,
"failed to determine destination element type");
965 int dstBits =
static_cast<int>(dstType.getWidth());
966 assert(dstBits % srcBits == 0);
968 if (srcBits == dstBits) {
970 if (
failed(memoryRequirements))
971 return rewriter.notifyMatchFailure(
972 storeOp,
"failed to determine memory requirements");
974 auto [memoryAccess, alignment] = *memoryRequirements;
975 Value storeVal = adaptor.getValue();
978 rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal,
979 memoryAccess, alignment);
985 if (typeConverter.allows(spirv::Capability::Kernel))
988 auto accessChainOp = accessChain.
getDefiningOp<spirv::AccessChainOp>();
1003 assert(accessChainOp.getIndices().size() == 2);
1004 Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
1009 Value mask = rewriter.createOrFold<spirv::ConstantOp>(
1010 loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
1011 Value clearBitsMask = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(
1012 loc, dstType, mask, offset);
1014 rewriter.createOrFold<spirv::NotOp>(loc, dstType, clearBitsMask);
1016 Value storeVal =
shiftValue(loc, adaptor.getValue(), offset, mask, rewriter);
1018 srcBits, dstBits, rewriter);
1021 return rewriter.notifyMatchFailure(storeOp,
"atomic scope not available");
1023 Value
result = spirv::AtomicAndOp::create(
1024 rewriter, loc, dstType, adjustedPtr, *scope,
1025 spirv::MemorySemantics::AcquireRelease, clearBitsMask);
1026 result = spirv::AtomicOrOp::create(
1027 rewriter, loc, dstType, adjustedPtr, *scope,
1028 spirv::MemorySemantics::AcquireRelease, storeVal);
1034 rewriter.eraseOp(storeOp);
1036 assert(accessChainOp.use_empty());
1037 rewriter.eraseOp(accessChainOp);
1046LogicalResult MemorySpaceCastOpPattern::matchAndRewrite(
1047 memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
1048 ConversionPatternRewriter &rewriter)
const {
1049 Location loc = addrCastOp.getLoc();
1050 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
1051 if (!typeConverter.allows(spirv::Capability::Kernel))
1052 return rewriter.notifyMatchFailure(
1053 loc,
"address space casts require kernel capability");
1055 auto sourceType = dyn_cast<MemRefType>(addrCastOp.getSource().getType());
1057 return rewriter.notifyMatchFailure(
1058 loc,
"SPIR-V lowering requires ranked memref types");
1059 auto resultType = cast<MemRefType>(addrCastOp.getResult().getType());
1061 auto sourceStorageClassAttr =
1062 dyn_cast_or_null<spirv::StorageClassAttr>(sourceType.getMemorySpace());
1063 if (!sourceStorageClassAttr)
1064 return rewriter.notifyMatchFailure(loc, [sourceType](Diagnostic &
diag) {
1065 diag <<
"source address space " << sourceType.getMemorySpace()
1066 <<
" must be a SPIR-V storage class";
1068 auto resultStorageClassAttr =
1069 dyn_cast_or_null<spirv::StorageClassAttr>(resultType.getMemorySpace());
1070 if (!resultStorageClassAttr)
1071 return rewriter.notifyMatchFailure(loc, [resultType](Diagnostic &
diag) {
1072 diag <<
"result address space " << resultType.getMemorySpace()
1073 <<
" must be a SPIR-V storage class";
1076 spirv::StorageClass sourceSc = sourceStorageClassAttr.getValue();
1077 spirv::StorageClass resultSc = resultStorageClassAttr.getValue();
1079 Value
result = adaptor.getSource();
1080 Type resultPtrType = typeConverter.convertType(resultType);
1082 return rewriter.notifyMatchFailure(addrCastOp,
1083 "failed to convert memref type");
1085 Type genericPtrType = resultPtrType;
1093 if (sourceSc != spirv::StorageClass::Generic &&
1094 resultSc != spirv::StorageClass::Generic) {
1095 Type intermediateType =
1096 MemRefType::get(sourceType.getShape(), sourceType.getElementType(),
1097 sourceType.getLayout(),
1098 rewriter.getAttr<spirv::StorageClassAttr>(
1099 spirv::StorageClass::Generic));
1100 genericPtrType = typeConverter.convertType(intermediateType);
1102 if (sourceSc != spirv::StorageClass::Generic) {
1103 result = spirv::PtrCastToGenericOp::create(rewriter, loc, genericPtrType,
1106 if (resultSc != spirv::StorageClass::Generic) {
1108 spirv::GenericCastToPtrOp::create(rewriter, loc, resultPtrType,
result);
1110 rewriter.replaceOp(addrCastOp,
result);
1115StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
1116 ConversionPatternRewriter &rewriter)
const {
1117 auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
1118 if (memrefType.getElementType().isSignlessInteger())
1119 return rewriter.notifyMatchFailure(storeOp,
"signless int");
1121 *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
1122 adaptor.getIndices(), storeOp.getLoc(), rewriter);
1125 return rewriter.notifyMatchFailure(storeOp,
"type conversion failed");
1128 if (
failed(memoryRequirements))
1129 return rewriter.notifyMatchFailure(
1130 storeOp,
"failed to determine memory requirements");
1132 auto [memoryAccess, alignment] = *memoryRequirements;
1133 rewriter.replaceOpWithNewOp<spirv::StoreOp>(
1134 storeOp, storePtr, adaptor.getValue(), memoryAccess, alignment);
1138LogicalResult ReinterpretCastPattern::matchAndRewrite(
1139 memref::ReinterpretCastOp op, OpAdaptor adaptor,
1140 ConversionPatternRewriter &rewriter)
const {
1141 Value src = adaptor.getSource();
1142 auto srcType = dyn_cast<spirv::PointerType>(src.
getType());
1145 return rewriter.notifyMatchFailure(op, [&](Diagnostic &
diag) {
1149 const TypeConverter *converter = getTypeConverter();
1151 auto dstType = converter->convertType<spirv::PointerType>(op.getType());
1152 if (dstType != srcType)
1153 return rewriter.notifyMatchFailure(op, [&](Diagnostic &
diag) {
1154 diag <<
"invalid dst type " << op.getType();
1157 OpFoldResult offset =
1158 getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter)
1161 rewriter.replaceOp(op, src);
1165 Type intType = converter->convertType(rewriter.getIndexType());
1167 return rewriter.notifyMatchFailure(op,
"failed to convert index type");
1169 Location loc = op.getLoc();
1170 auto offsetValue = [&]() -> Value {
1171 if (
auto val = dyn_cast<Value>(offset))
1174 int64_t attrVal = cast<IntegerAttr>(cast<Attribute>(offset)).getInt();
1175 Attribute attr = rewriter.getIntegerAttr(intType, attrVal);
1176 return rewriter.createOrFold<spirv::ConstantOp>(loc, intType, attr);
1179 rewriter.replaceOpWithNewOp<spirv::InBoundsPtrAccessChainOp>(
1188LogicalResult ExtractAlignedPointerAsIndexOpPattern::matchAndRewrite(
1189 memref::ExtractAlignedPointerAsIndexOp extractOp, OpAdaptor adaptor,
1190 ConversionPatternRewriter &rewriter)
const {
1191 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
1192 Type indexType = typeConverter.getIndexType();
1193 rewriter.replaceOpWithNewOp<spirv::ConvertPtrToUOp>(extractOp, indexType,
1194 adaptor.getSource());
1205 patterns.
add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
1206 DeallocOpPattern, IntLoadOpPattern, ImageLoadOpPattern,
1207 IntStoreOpPattern, LoadOpPattern, MemorySpaceCastOpPattern,
1208 StoreOpPattern, ReinterpretCastPattern, CastPattern,
1209 ExtractAlignedPointerAsIndexOpPattern>(typeConverter,
static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder)
Casts the given srcInt into a boolean value.
static Type getElementTypeForStoragePointer(Type pointeeType, const SPIRVTypeConverter &typeConverter)
Extracts the element type from a SPIR-V pointer type pointing to storage.
static std::optional< spirv::Scope > getAtomicOpScope(MemRefType type)
Returns the scope to use for atomic operations use for emulating store operations of unsupported inte...
static Value shiftValue(Location loc, Value value, Value offset, Value mask, OpBuilder &builder)
Returns the targetBits-bit value shifted by the given offset, and cast to the type destination type,...
static FailureOr< SmallVector< Value > > extractLoadCoordsForComposite(memref::LoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter)
static Value adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter, spirv::AccessChainOp op, int sourceBits, int targetBits, OpBuilder &builder)
Returns an adjusted spirv::AccessChainOp.
static bool isAllocationSupported(Operation *allocOp, MemRefType type)
Returns true if the allocations of memref type generated from allocOp can be lowered to SPIR-V.
static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits, int targetBits, OpBuilder &builder)
Returns the offset of the value in targetBits representation.
#define ATOMIC_CASE(kind, spirvOp)
static FailureOr< MemoryRequirements > calculateMemoryRequirements(Value accessedPtr, bool isNontemporal, uint64_t preferredAlignment)
Given an accessed SPIR-V pointer, calculates its alignment requirements, if any.
static Value castBoolToIntN(Location loc, Value srcBool, Type dstType, OpBuilder &builder)
Casts the given srcBool into an integer of dstType.
static std::string diag(const llvm::Value &value)
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
unsigned getNumDims() const
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerType getIntegerType(unsigned width)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Type conversion from builtin types to SPIR-V types for shader interface.
bool allows(spirv::Capability capability) const
Checks if the SPIR-V capability inquired is supported.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Include the generated interface declarations.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
void populateMemRefToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating MemRef ops to SPIR-V ops.
spirv::MemoryAccessAttr memoryAccess