28#define DEBUG_TYPE "memref-to-spirv-pattern"
49 assert(targetBits % sourceBits == 0);
51 IntegerAttr idxAttr = builder.
getIntegerAttr(type, targetBits / sourceBits);
52 auto idx = builder.
createOrFold<spirv::ConstantOp>(loc, type, idxAttr);
53 IntegerAttr srcBitsAttr = builder.
getIntegerAttr(type, sourceBits);
55 builder.
createOrFold<spirv::ConstantOp>(loc, type, srcBitsAttr);
56 auto m = builder.
createOrFold<spirv::UModOp>(loc, srcIdx, idx);
57 return builder.
createOrFold<spirv::IMulOp>(loc, type, m, srcBitsValue);
70 spirv::AccessChainOp op,
int sourceBits,
72 assert(targetBits % sourceBits == 0);
73 const auto loc = op.getLoc();
74 Value lastDim = op->getOperand(op.getNumOperands() - 1);
76 IntegerAttr attr = builder.
getIntegerAttr(type, targetBits / sourceBits);
77 auto idx = builder.
createOrFold<spirv::ConstantOp>(loc, type, attr);
78 auto indices = llvm::to_vector<4>(op.getIndices());
82 Type t = typeConverter.convertType(op.getComponentPtr().getType());
83 return spirv::AccessChainOp::create(builder, loc, t, op.getBasePtr(),
93 Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
94 Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
95 return builder.
createOrFold<spirv::SelectOp>(loc, dstType, srcBool, one,
103 IntegerType dstType = cast<IntegerType>(mask.
getType());
104 int targetBits =
static_cast<int>(dstType.getWidth());
106 assert(valueBits <= targetBits);
108 if (valueBits == 1) {
111 if (valueBits < targetBits) {
112 value = spirv::UConvertOp::create(
116 value = builder.
createOrFold<spirv::BitwiseAndOp>(loc, value, mask);
125 if (isa<memref::AllocOp, memref::DeallocOp>(allocOp)) {
126 auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
127 if (!sc || sc.getValue() != spirv::StorageClass::Workgroup)
129 }
else if (isa<memref::AllocaOp>(allocOp)) {
130 auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
131 if (!sc || sc.getValue() != spirv::StorageClass::Function)
139 if (!type.hasStaticShape())
142 Type elementType = type.getElementType();
143 if (
auto vecType = dyn_cast<VectorType>(elementType))
144 elementType = vecType.getElementType();
145 if (
auto compType = dyn_cast<ComplexType>(elementType))
146 elementType = compType.getElementType();
154 auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
155 switch (sc.getValue()) {
156 case spirv::StorageClass::StorageBuffer:
157 return spirv::Scope::Device;
158 case spirv::StorageClass::Workgroup:
159 return spirv::Scope::Workgroup;
169static spirv::MemorySemantics
172 case spirv::StorageClass::StorageBuffer:
173 case spirv::StorageClass::Uniform:
174 return spirv::MemorySemantics::UniformMemory;
175 case spirv::StorageClass::Workgroup:
176 return spirv::MemorySemantics::WorkgroupMemory;
177 case spirv::StorageClass::CrossWorkgroup:
178 return spirv::MemorySemantics::CrossWorkgroupMemory;
179 case spirv::StorageClass::AtomicCounter:
180 return spirv::MemorySemantics::AtomicCounterMemory;
181 case spirv::StorageClass::Image:
182 return spirv::MemorySemantics::ImageMemory;
184 return spirv::MemorySemantics::None;
191 auto sc = cast<spirv::StorageClassAttr>(type.getMemorySpace()).getValue();
192 return spirv::MemorySemantics::AcquireRelease |
205 if (typeConverter.
allows(spirv::Capability::Kernel)) {
206 if (
auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
207 return arrayType.getElementType();
211 Type structElemType = cast<spirv::StructType>(pointeeType).getElementType(0);
212 if (
auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
213 return arrayType.getElementType();
214 return cast<spirv::RuntimeArrayType>(structElemType).getElementType();
222 auto one = spirv::ConstantOp::getZero(srcInt.
getType(), loc, builder);
223 return builder.
createOrFold<spirv::INotEqualOp>(loc, srcInt, one);
237class AllocaOpPattern final :
public OpConversionPattern<memref::AllocaOp> {
242 matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
243 ConversionPatternRewriter &rewriter)
const override;
250class AllocOpPattern final :
public OpConversionPattern<memref::AllocOp> {
255 matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
256 ConversionPatternRewriter &rewriter)
const override;
260class AtomicRMWOpPattern final
261 :
public OpConversionPattern<memref::AtomicRMWOp> {
266 matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
267 ConversionPatternRewriter &rewriter)
const override;
272class DeallocOpPattern final :
public OpConversionPattern<memref::DeallocOp> {
277 matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor,
278 ConversionPatternRewriter &rewriter)
const override;
282class IntLoadOpPattern final :
public OpConversionPattern<memref::LoadOp> {
287 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
288 ConversionPatternRewriter &rewriter)
const override;
292class LoadOpPattern final :
public OpConversionPattern<memref::LoadOp> {
297 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
298 ConversionPatternRewriter &rewriter)
const override;
302class ImageLoadOpPattern final :
public OpConversionPattern<memref::LoadOp> {
307 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
308 ConversionPatternRewriter &rewriter)
const override;
312class IntStoreOpPattern final :
public OpConversionPattern<memref::StoreOp> {
317 matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
318 ConversionPatternRewriter &rewriter)
const override;
322class MemorySpaceCastOpPattern final
323 :
public OpConversionPattern<memref::MemorySpaceCastOp> {
328 matchAndRewrite(memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
329 ConversionPatternRewriter &rewriter)
const override;
333class StoreOpPattern final :
public OpConversionPattern<memref::StoreOp> {
338 matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
339 ConversionPatternRewriter &rewriter)
const override;
342class ReinterpretCastPattern final
343 :
public OpConversionPattern<memref::ReinterpretCastOp> {
348 matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
349 ConversionPatternRewriter &rewriter)
const override;
352class CastPattern final :
public OpConversionPattern<memref::CastOp> {
357 matchAndRewrite(memref::CastOp op, OpAdaptor adaptor,
358 ConversionPatternRewriter &rewriter)
const override {
359 Value src = adaptor.getSource();
362 const TypeConverter *converter = getTypeConverter();
363 Type dstType = converter->convertType(op.getType());
364 if (srcType != dstType)
365 return rewriter.notifyMatchFailure(op, [&](Diagnostic &
diag) {
366 diag <<
"types doesn't match: " << srcType <<
" and " << dstType;
369 rewriter.replaceOp(op, src);
375class ExtractAlignedPointerAsIndexOpPattern final
376 :
public OpConversionPattern<memref::ExtractAlignedPointerAsIndexOp> {
381 matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
383 ConversionPatternRewriter &rewriter)
const override;
392AllocaOpPattern::matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
393 ConversionPatternRewriter &rewriter)
const {
394 MemRefType allocType = allocaOp.getType();
396 return rewriter.notifyMatchFailure(allocaOp,
"unhandled allocation type");
399 Type spirvType = getTypeConverter()->convertType(allocType);
401 return rewriter.notifyMatchFailure(allocaOp,
"type conversion failed");
403 rewriter.replaceOpWithNewOp<spirv::VariableOp>(allocaOp, spirvType,
404 spirv::StorageClass::Function,
414AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
415 ConversionPatternRewriter &rewriter)
const {
416 MemRefType allocType = operation.getType();
418 return rewriter.notifyMatchFailure(operation,
"unhandled allocation type");
421 Type spirvType = getTypeConverter()->convertType(allocType);
423 return rewriter.notifyMatchFailure(operation,
"type conversion failed");
430 Location loc = operation.getLoc();
431 spirv::GlobalVariableOp varOp;
433 OpBuilder::InsertionGuard guard(rewriter);
435 rewriter.setInsertionPointToStart(&entryBlock);
436 auto varOps = entryBlock.
getOps<spirv::GlobalVariableOp>();
437 std::string varName =
438 std::string(
"__workgroup_mem__") +
439 std::to_string(std::distance(varOps.begin(), varOps.end()));
440 varOp = spirv::GlobalVariableOp::create(rewriter, loc, spirvType, varName,
445 rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(operation, varOp);
454AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
456 ConversionPatternRewriter &rewriter)
const {
457 if (isa<FloatType>(atomicOp.getType()))
458 return rewriter.notifyMatchFailure(atomicOp,
459 "unimplemented floating-point case");
461 auto memrefType = cast<MemRefType>(atomicOp.getMemref().getType());
464 return rewriter.notifyMatchFailure(atomicOp,
465 "unsupported memref memory space");
467 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
468 Type resultType = typeConverter.convertType(atomicOp.getType());
470 return rewriter.notifyMatchFailure(atomicOp,
471 "failed to convert result type");
473 auto loc = atomicOp.getLoc();
476 adaptor.getIndices(), loc, rewriter);
484 int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
485 auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
487 return rewriter.notifyMatchFailure(atomicOp,
488 "failed to convert memref type");
490 Type pointeeType = pointerType.getPointeeType();
491 auto dstType = dyn_cast<IntegerType>(
494 return rewriter.notifyMatchFailure(
495 atomicOp,
"failed to determine destination element type");
497 int dstBits =
static_cast<int>(dstType.getWidth());
498 assert(dstBits % srcBits == 0);
504 if (srcBits == dstBits) {
505#define ATOMIC_CASE(kind, spirvOp) \
506 case arith::AtomicRMWKind::kind: \
507 rewriter.replaceOpWithNewOp<spirv::spirvOp>( \
508 atomicOp, resultType, ptr, *scope, memSem, adaptor.getValue()); \
511 switch (atomicOp.getKind()) {
520 return rewriter.notifyMatchFailure(atomicOp,
"unimplemented atomic kind");
535 if (atomicOp.getKind() != arith::AtomicRMWKind::ori &&
536 atomicOp.getKind() != arith::AtomicRMWKind::andi) {
537 return rewriter.notifyMatchFailure(
539 "atomic op on sub-element-width types is only supported for ori/andi");
544 if (typeConverter.allows(spirv::Capability::Kernel))
545 return rewriter.notifyMatchFailure(
547 "sub-element-width atomic ops unsupported with Kernel capability");
549 auto accessChainOp = ptr.
getDefiningOp<spirv::AccessChainOp>();
555 assert(accessChainOp.getIndices().size() == 2);
556 Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
559 srcBits, dstBits, rewriter);
561 switch (atomicOp.getKind()) {
562 case arith::AtomicRMWKind::ori: {
565 Value elemMask = rewriter.createOrFold<spirv::ConstantOp>(
566 loc, dstType, rewriter.getIntegerAttr(dstType, (1uLL << srcBits) - 1));
568 shiftValue(loc, adaptor.getValue(), offset, elemMask, rewriter);
569 result = spirv::AtomicOrOp::create(rewriter, loc, dstType, adjustedPtr,
570 *scope, memSem, storeVal);
573 case arith::AtomicRMWKind::andi: {
577 Value elemMask = rewriter.createOrFold<spirv::ConstantOp>(
578 loc, dstType, rewriter.getIntegerAttr(dstType, (1uLL << srcBits) - 1));
580 shiftValue(loc, adaptor.getValue(), offset, elemMask, rewriter);
581 Value shiftedElemMask = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(
582 loc, dstType, elemMask, offset);
583 Value invertedElemMask =
584 rewriter.createOrFold<spirv::NotOp>(loc, dstType, shiftedElemMask);
585 Value mask = rewriter.createOrFold<spirv::BitwiseOrOp>(loc, storeVal,
587 result = spirv::AtomicAndOp::create(rewriter, loc, dstType, adjustedPtr,
588 *scope, memSem, mask);
592 return rewriter.notifyMatchFailure(atomicOp,
"unimplemented atomic kind");
597 result = rewriter.createOrFold<spirv::ShiftRightLogicalOp>(loc, dstType,
599 Value mask = rewriter.createOrFold<spirv::ConstantOp>(
600 loc, dstType, rewriter.getIntegerAttr(dstType, (1uLL << srcBits) - 1));
602 rewriter.createOrFold<spirv::BitwiseAndOp>(loc, dstType,
result, mask);
603 rewriter.replaceOp(atomicOp,
result);
613DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
615 ConversionPatternRewriter &rewriter)
const {
616 MemRefType deallocType = cast<MemRefType>(operation.getMemref().getType());
618 return rewriter.notifyMatchFailure(operation,
"unhandled allocation type");
619 rewriter.eraseOp(operation);
634static FailureOr<MemoryRequirements>
636 uint64_t preferredAlignment) {
637 if (preferredAlignment >= std::numeric_limits<uint32_t>::max()) {
643 auto memoryAccess = spirv::MemoryAccess::None;
645 memoryAccess = spirv::MemoryAccess::Nontemporal;
648 auto ptrType = cast<spirv::PointerType>(accessedPtr.
getType());
649 bool mayOmitAlignment =
650 !preferredAlignment &&
651 ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer;
652 if (mayOmitAlignment) {
653 if (memoryAccess == spirv::MemoryAccess::None) {
662 std::optional<int64_t> sizeInBytes;
663 Type rawPointeeType = ptrType.getPointeeType();
664 if (
auto scalarType = dyn_cast<spirv::ScalarType>(rawPointeeType)) {
666 sizeInBytes = scalarType.getSizeInBytes();
667 }
else if (
auto vecType = dyn_cast<VectorType>(rawPointeeType)) {
670 if (
auto scalarElem =
671 dyn_cast<spirv::ScalarType>(vecType.getElementType())) {
672 if (
auto elemSize = scalarElem.getSizeInBytes())
673 sizeInBytes = *elemSize * vecType.getNumElements();
677 if (!sizeInBytes.has_value())
680 memoryAccess |= spirv::MemoryAccess::Aligned;
681 auto memAccessAttr = spirv::MemoryAccessAttr::get(ctx, memoryAccess);
682 auto alignmentValue = preferredAlignment ? preferredAlignment : *sizeInBytes;
683 auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), alignmentValue);
690template <
class LoadOrStoreOp>
691static FailureOr<MemoryRequirements>
694 llvm::is_one_of<LoadOrStoreOp, memref::LoadOp, memref::StoreOp>::value,
695 "Must be called on either memref::LoadOp or memref::StoreOp");
698 loadOrStoreOp.getNontemporal(),
699 loadOrStoreOp.getAlignment().value_or(0));
703IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
704 ConversionPatternRewriter &rewriter)
const {
705 auto loc = loadOp.getLoc();
706 auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
707 if (!memrefType.getElementType().isSignlessInteger())
710 auto memorySpaceAttr =
711 dyn_cast_if_present<spirv::StorageClassAttr>(memrefType.getMemorySpace());
712 if (!memorySpaceAttr)
713 return rewriter.notifyMatchFailure(
714 loadOp,
"missing memory space SPIR-V storage class attribute");
716 if (memorySpaceAttr.getValue() == spirv::StorageClass::Image)
717 return rewriter.notifyMatchFailure(
719 "failed to lower memref in image storage class to storage buffer");
721 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
724 adaptor.getIndices(), loc, rewriter);
729 int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
730 bool isBool = srcBits == 1;
732 srcBits = typeConverter.getOptions().boolNumBits;
734 auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
736 return rewriter.notifyMatchFailure(loadOp,
"failed to convert memref type");
738 Type pointeeType = pointerType.getPointeeType();
741 assert(dstBits % srcBits == 0);
745 if (srcBits == dstBits) {
747 if (
failed(memoryRequirements))
748 return rewriter.notifyMatchFailure(
749 loadOp,
"failed to determine memory requirements");
751 auto [memoryAccess, alignment] = *memoryRequirements;
752 Value loadVal = spirv::LoadOp::create(rewriter, loc, accessChain,
753 memoryAccess, alignment);
756 rewriter.replaceOp(loadOp, loadVal);
762 if (typeConverter.allows(spirv::Capability::Kernel))
765 auto accessChainOp = accessChain.
getDefiningOp<spirv::AccessChainOp>();
772 assert(accessChainOp.getIndices().size() == 2);
774 srcBits, dstBits, rewriter);
776 if (
failed(memoryRequirements))
777 return rewriter.notifyMatchFailure(
778 loadOp,
"failed to determine memory requirements");
780 auto [memoryAccess, alignment] = *memoryRequirements;
781 Value spvLoadOp = spirv::LoadOp::create(rewriter, loc, dstType, adjustedPtr,
782 memoryAccess, alignment);
786 Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
788 Value
result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
789 loc, spvLoadOp.
getType(), spvLoadOp, offset);
792 Value mask = rewriter.createOrFold<spirv::ConstantOp>(
793 loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
795 rewriter.createOrFold<spirv::BitwiseAndOp>(loc, dstType,
result, mask);
800 IntegerAttr shiftValueAttr =
801 rewriter.getIntegerAttr(dstType, dstBits - srcBits);
803 rewriter.createOrFold<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
804 result = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(loc, dstType,
806 result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
809 rewriter.replaceOp(loadOp,
result);
811 assert(accessChainOp.use_empty());
812 rewriter.eraseOp(accessChainOp);
818LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
819 ConversionPatternRewriter &rewriter)
const {
820 auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
821 if (memrefType.getElementType().isSignlessInteger())
824 auto memorySpaceAttr =
825 dyn_cast_if_present<spirv::StorageClassAttr>(memrefType.getMemorySpace());
826 if (!memorySpaceAttr)
827 return rewriter.notifyMatchFailure(
828 loadOp,
"missing memory space SPIR-V storage class attribute");
830 if (memorySpaceAttr.getValue() == spirv::StorageClass::Image)
831 return rewriter.notifyMatchFailure(
833 "failed to lower memref in image storage class to storage buffer");
836 *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
837 adaptor.getIndices(), loadOp.getLoc(), rewriter);
843 if (
failed(memoryRequirements))
844 return rewriter.notifyMatchFailure(
845 loadOp,
"failed to determine memory requirements");
847 auto [memoryAccess, alignment] = *memoryRequirements;
848 rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr, memoryAccess,
853template <
typename OpAdaptor>
854static FailureOr<SmallVector<Value>>
856 ConversionPatternRewriter &rewriter) {
863 AffineMap map = loadOp.getMemRefType().getLayout().getAffineMap();
865 return rewriter.notifyMatchFailure(
867 "Cannot lower memrefs with memory layout which is not a permutation");
873 for (
unsigned dim = 0; dim < dimCount; ++dim)
879 return llvm::to_vector(llvm::reverse(coords));
883ImageLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
884 ConversionPatternRewriter &rewriter)
const {
885 auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
887 auto memorySpaceAttr =
888 dyn_cast_if_present<spirv::StorageClassAttr>(memrefType.getMemorySpace());
889 if (!memorySpaceAttr)
890 return rewriter.notifyMatchFailure(
891 loadOp,
"missing memory space SPIR-V storage class attribute");
893 if (memorySpaceAttr.getValue() != spirv::StorageClass::Image)
894 return rewriter.notifyMatchFailure(
895 loadOp,
"failed to lower memref in non-image storage class to image");
897 Value loadPtr = adaptor.getMemref();
899 if (
failed(memoryRequirements))
900 return rewriter.notifyMatchFailure(
901 loadOp,
"failed to determine memory requirements");
903 const auto [memoryAccess, alignment] = *memoryRequirements;
905 if (!loadOp.getMemRefType().hasRank())
906 return rewriter.notifyMatchFailure(
907 loadOp,
"cannot lower unranked memrefs to SPIR-V images");
912 if (!isa<spirv::ScalarType>(loadOp.getMemRefType().getElementType()))
913 return rewriter.notifyMatchFailure(
915 "cannot lower memrefs who's element type is not a SPIR-V scalar type"
922 auto convertedPointeeType = cast<spirv::PointerType>(
923 getTypeConverter()->convertType(loadOp.getMemRefType()));
924 if (!isa<spirv::SampledImageType>(convertedPointeeType.getPointeeType()))
925 return rewriter.notifyMatchFailure(loadOp,
926 "cannot lower memrefs which do not "
927 "convert to SPIR-V sampled images");
930 Location loc = loadOp->getLoc();
932 spirv::LoadOp::create(rewriter, loc, loadPtr, memoryAccess, alignment);
934 auto imageOp = spirv::ImageOp::create(rewriter, loc, imageLoadOp);
938 if (memrefType.getRank() == 1) {
939 coords = adaptor.getIndices()[0];
941 FailureOr<SmallVector<Value>> maybeCoords =
945 auto coordVectorType = VectorType::get({loadOp.getMemRefType().getRank()},
946 adaptor.getIndices().
getType()[0]);
947 coords = spirv::CompositeConstructOp::create(rewriter, loc, coordVectorType,
948 maybeCoords.value());
952 auto resultVectorType = VectorType::get({4}, loadOp.getType());
953 auto fetchOp = spirv::ImageFetchOp::create(
954 rewriter, loc, resultVectorType, imageOp, coords,
955 mlir::spirv::ImageOperandsAttr{},
ValueRange{});
960 auto compositeExtractOp =
961 spirv::CompositeExtractOp::create(rewriter, loc, fetchOp, 0);
963 rewriter.replaceOp(loadOp, compositeExtractOp);
968IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
969 ConversionPatternRewriter &rewriter)
const {
970 auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
971 if (!memrefType.getElementType().isSignlessInteger())
972 return rewriter.notifyMatchFailure(storeOp,
973 "element type is not a signless int");
975 auto loc = storeOp.getLoc();
976 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
979 adaptor.getIndices(), loc, rewriter);
982 return rewriter.notifyMatchFailure(
983 storeOp,
"failed to convert element pointer type");
985 int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
987 bool isBool = srcBits == 1;
989 srcBits = typeConverter.getOptions().boolNumBits;
991 auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
993 return rewriter.notifyMatchFailure(storeOp,
994 "failed to convert memref type");
996 Type pointeeType = pointerType.getPointeeType();
997 auto dstType = dyn_cast<IntegerType>(
1000 return rewriter.notifyMatchFailure(
1001 storeOp,
"failed to determine destination element type");
1003 int dstBits =
static_cast<int>(dstType.getWidth());
1004 assert(dstBits % srcBits == 0);
1006 if (srcBits == dstBits) {
1008 if (
failed(memoryRequirements))
1009 return rewriter.notifyMatchFailure(
1010 storeOp,
"failed to determine memory requirements");
1012 auto [memoryAccess, alignment] = *memoryRequirements;
1013 Value storeVal = adaptor.getValue();
1016 rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal,
1017 memoryAccess, alignment);
1023 if (typeConverter.allows(spirv::Capability::Kernel))
1026 auto accessChainOp = accessChain.
getDefiningOp<spirv::AccessChainOp>();
1041 assert(accessChainOp.getIndices().size() == 2);
1042 Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
1047 Value mask = rewriter.createOrFold<spirv::ConstantOp>(
1048 loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
1049 Value clearBitsMask = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(
1050 loc, dstType, mask, offset);
1052 rewriter.createOrFold<spirv::NotOp>(loc, dstType, clearBitsMask);
1054 Value storeVal =
shiftValue(loc, adaptor.getValue(), offset, mask, rewriter);
1056 srcBits, dstBits, rewriter);
1059 return rewriter.notifyMatchFailure(storeOp,
"atomic scope not available");
1062 Value
result = spirv::AtomicAndOp::create(rewriter, loc, dstType, adjustedPtr,
1063 *scope, memSem, clearBitsMask);
1064 result = spirv::AtomicOrOp::create(rewriter, loc, dstType, adjustedPtr,
1065 *scope, memSem, storeVal);
1071 rewriter.eraseOp(storeOp);
1073 assert(accessChainOp.use_empty());
1074 rewriter.eraseOp(accessChainOp);
1083LogicalResult MemorySpaceCastOpPattern::matchAndRewrite(
1084 memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
1085 ConversionPatternRewriter &rewriter)
const {
1086 Location loc = addrCastOp.getLoc();
1087 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
1088 if (!typeConverter.allows(spirv::Capability::Kernel))
1089 return rewriter.notifyMatchFailure(
1090 loc,
"address space casts require kernel capability");
1092 auto sourceType = dyn_cast<MemRefType>(addrCastOp.getSource().getType());
1094 return rewriter.notifyMatchFailure(
1095 loc,
"SPIR-V lowering requires ranked memref types");
1096 auto resultType = cast<MemRefType>(addrCastOp.getResult().getType());
1098 auto sourceStorageClassAttr =
1099 dyn_cast_or_null<spirv::StorageClassAttr>(sourceType.getMemorySpace());
1100 if (!sourceStorageClassAttr)
1101 return rewriter.notifyMatchFailure(loc, [sourceType](Diagnostic &
diag) {
1102 diag <<
"source address space " << sourceType.getMemorySpace()
1103 <<
" must be a SPIR-V storage class";
1105 auto resultStorageClassAttr =
1106 dyn_cast_or_null<spirv::StorageClassAttr>(resultType.getMemorySpace());
1107 if (!resultStorageClassAttr)
1108 return rewriter.notifyMatchFailure(loc, [resultType](Diagnostic &
diag) {
1109 diag <<
"result address space " << resultType.getMemorySpace()
1110 <<
" must be a SPIR-V storage class";
1113 spirv::StorageClass sourceSc = sourceStorageClassAttr.getValue();
1114 spirv::StorageClass resultSc = resultStorageClassAttr.getValue();
1116 Value
result = adaptor.getSource();
1117 Type resultPtrType = typeConverter.convertType(resultType);
1119 return rewriter.notifyMatchFailure(addrCastOp,
1120 "failed to convert memref type");
1122 Type genericPtrType = resultPtrType;
1130 if (sourceSc != spirv::StorageClass::Generic &&
1131 resultSc != spirv::StorageClass::Generic) {
1132 Type intermediateType =
1133 MemRefType::get(sourceType.getShape(), sourceType.getElementType(),
1134 sourceType.getLayout(),
1135 rewriter.getAttr<spirv::StorageClassAttr>(
1136 spirv::StorageClass::Generic));
1137 genericPtrType = typeConverter.convertType(intermediateType);
1139 if (sourceSc != spirv::StorageClass::Generic) {
1140 result = spirv::PtrCastToGenericOp::create(rewriter, loc, genericPtrType,
1143 if (resultSc != spirv::StorageClass::Generic) {
1145 spirv::GenericCastToPtrOp::create(rewriter, loc, resultPtrType,
result);
1147 rewriter.replaceOp(addrCastOp,
result);
1152StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
1153 ConversionPatternRewriter &rewriter)
const {
1154 auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
1155 if (memrefType.getElementType().isSignlessInteger())
1156 return rewriter.notifyMatchFailure(storeOp,
"signless int");
1158 *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
1159 adaptor.getIndices(), storeOp.getLoc(), rewriter);
1162 return rewriter.notifyMatchFailure(storeOp,
"type conversion failed");
1165 if (
failed(memoryRequirements))
1166 return rewriter.notifyMatchFailure(
1167 storeOp,
"failed to determine memory requirements");
1169 auto [memoryAccess, alignment] = *memoryRequirements;
1170 rewriter.replaceOpWithNewOp<spirv::StoreOp>(
1171 storeOp, storePtr, adaptor.getValue(), memoryAccess, alignment);
1175LogicalResult ReinterpretCastPattern::matchAndRewrite(
1176 memref::ReinterpretCastOp op, OpAdaptor adaptor,
1177 ConversionPatternRewriter &rewriter)
const {
1178 Value src = adaptor.getSource();
1179 auto srcType = dyn_cast<spirv::PointerType>(src.
getType());
1182 return rewriter.notifyMatchFailure(op, [&](Diagnostic &
diag) {
1186 const TypeConverter *converter = getTypeConverter();
1188 auto dstType = converter->convertType<spirv::PointerType>(op.getType());
1189 if (dstType != srcType)
1190 return rewriter.notifyMatchFailure(op, [&](Diagnostic &
diag) {
1191 diag <<
"invalid dst type " << op.getType();
1194 OpFoldResult offset =
1195 getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter)
1198 rewriter.replaceOp(op, src);
1202 Type intType = converter->convertType(rewriter.getIndexType());
1204 return rewriter.notifyMatchFailure(op,
"failed to convert index type");
1206 Location loc = op.getLoc();
1207 auto offsetValue = [&]() -> Value {
1208 if (
auto val = dyn_cast<Value>(offset))
1211 int64_t attrVal = cast<IntegerAttr>(cast<Attribute>(offset)).getInt();
1212 Attribute attr = rewriter.getIntegerAttr(intType, attrVal);
1213 return rewriter.createOrFold<spirv::ConstantOp>(loc, intType, attr);
1216 rewriter.replaceOpWithNewOp<spirv::InBoundsPtrAccessChainOp>(
1225LogicalResult ExtractAlignedPointerAsIndexOpPattern::matchAndRewrite(
1226 memref::ExtractAlignedPointerAsIndexOp extractOp, OpAdaptor adaptor,
1227 ConversionPatternRewriter &rewriter)
const {
1228 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
1229 Type indexType = typeConverter.getIndexType();
1230 rewriter.replaceOpWithNewOp<spirv::ConvertPtrToUOp>(extractOp, indexType,
1231 adaptor.getSource());
1242 patterns.
add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
1243 DeallocOpPattern, IntLoadOpPattern, ImageLoadOpPattern,
1244 IntStoreOpPattern, LoadOpPattern, MemorySpaceCastOpPattern,
1245 StoreOpPattern, ReinterpretCastPattern, CastPattern,
1246 ExtractAlignedPointerAsIndexOpPattern>(typeConverter,
static spirv::MemorySemantics getMemorySemanticsForStorageClass(spirv::StorageClass sc)
Returns the MemorySemantics storage-class bit corresponding to sc.
static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder)
Casts the given srcInt into a boolean value.
static Type getElementTypeForStoragePointer(Type pointeeType, const SPIRVTypeConverter &typeConverter)
Extracts the element type from a SPIR-V pointer type pointing to storage.
static std::optional< spirv::Scope > getAtomicOpScope(MemRefType type)
Returns the scope to use for atomic operations used for emulating store operations of unsupported integer bitwidths.
static Value shiftValue(Location loc, Value value, Value offset, Value mask, OpBuilder &builder)
Returns the targetBits-bit value shifted by the given offset, and cast to the destination type.
static FailureOr< SmallVector< Value > > extractLoadCoordsForComposite(memref::LoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter)
static Value adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter, spirv::AccessChainOp op, int sourceBits, int targetBits, OpBuilder &builder)
Returns an adjusted spirv::AccessChainOp.
static bool isAllocationSupported(Operation *allocOp, MemRefType type)
Returns true if the allocations of memref type generated from allocOp can be lowered to SPIR-V.
static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits, int targetBits, OpBuilder &builder)
Returns the offset of the value in targetBits representation.
static spirv::MemorySemantics getAtomicAcqRelMemorySemantics(MemRefType type)
Returns the AcquireRelease memory semantics OR'd with the storage-class bit derived from the memory space of the given memref type.
#define ATOMIC_CASE(kind, spirvOp)
static FailureOr< MemoryRequirements > calculateMemoryRequirements(Value accessedPtr, bool isNontemporal, uint64_t preferredAlignment)
Given an accessed SPIR-V pointer, calculates its alignment requirements, if any.
static Value castBoolToIntN(Location loc, Value srcBool, Type dstType, OpBuilder &builder)
Casts the given srcBool into an integer of dstType.
static std::string diag(const llvm::Value &value)
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is safe to do so.
unsigned getNumDims() const
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerType getIntegerType(unsigned width)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a location attribute.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold it.
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Type conversion from builtin types to SPIR-V types for shader interface.
bool allows(spirv::Capability capability) const
Checks if the SPIR-V capability inquired is supported.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
bool isInteger() const
Return true if this is an integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr, using the layout map of baseType.
Include the generated interface declarations.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but with all elements for which ShapedType::isDynamic is true replaced by dynamicValues.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
void populateMemRefToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating MemRef ops to SPIR-V ops.
spirv::MemoryAccessAttr memoryAccess