27 #define DEBUG_TYPE "memref-to-spirv-pattern"
48 assert(targetBits % sourceBits == 0);
50 IntegerAttr idxAttr = builder.
getIntegerAttr(type, targetBits / sourceBits);
51 auto idx = builder.
createOrFold<spirv::ConstantOp>(loc, type, idxAttr);
52 IntegerAttr srcBitsAttr = builder.
getIntegerAttr(type, sourceBits);
54 builder.
createOrFold<spirv::ConstantOp>(loc, type, srcBitsAttr);
55 auto m = builder.
createOrFold<spirv::UModOp>(loc, srcIdx, idx);
56 return builder.
createOrFold<spirv::IMulOp>(loc, type, m, srcBitsValue);
69 spirv::AccessChainOp op,
int sourceBits,
71 assert(targetBits % sourceBits == 0);
72 const auto loc = op.getLoc();
73 Value lastDim = op->getOperand(op.getNumOperands() - 1);
75 IntegerAttr attr = builder.
getIntegerAttr(type, targetBits / sourceBits);
76 auto idx = builder.
createOrFold<spirv::ConstantOp>(loc, type, attr);
77 auto indices = llvm::to_vector<4>(op.getIndices());
79 assert(indices.size() == 2);
80 indices.back() = builder.
createOrFold<spirv::SDivOp>(loc, lastDim, idx);
82 return builder.
create<spirv::AccessChainOp>(loc, t, op.getBasePtr(), indices);
92 Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
93 return builder.
createOrFold<spirv::SelectOp>(loc, dstType, srcBool, one,
101 IntegerType dstType = cast<IntegerType>(mask.
getType());
102 int targetBits =
static_cast<int>(dstType.getWidth());
104 assert(valueBits <= targetBits);
106 if (valueBits == 1) {
109 if (valueBits < targetBits) {
110 value = builder.
create<spirv::UConvertOp>(
114 value = builder.
createOrFold<spirv::BitwiseAndOp>(loc, value, mask);
123 if (isa<memref::AllocOp, memref::DeallocOp>(allocOp)) {
124 auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
125 if (!sc || sc.getValue() != spirv::StorageClass::Workgroup)
127 }
else if (isa<memref::AllocaOp>(allocOp)) {
128 auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
129 if (!sc || sc.getValue() != spirv::StorageClass::Function)
137 if (!type.hasStaticShape())
140 Type elementType = type.getElementType();
141 if (
auto vecType = dyn_cast<VectorType>(elementType))
142 elementType = vecType.getElementType();
150 auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
151 switch (sc.getValue()) {
152 case spirv::StorageClass::StorageBuffer:
153 return spirv::Scope::Device;
154 case spirv::StorageClass::Workgroup:
155 return spirv::Scope::Workgroup;
168 return builder.
createOrFold<spirv::INotEqualOp>(loc, srcInt, one);
187 matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
200 matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
205 class AtomicRMWOpPattern final
211 matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
222 matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor,
232 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
242 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
252 matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
257 class MemorySpaceCastOpPattern final
263 matchAndRewrite(memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
273 matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
277 class ReinterpretCastPattern final
283 matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
292 matchAndRewrite(memref::CastOp op, OpAdaptor adaptor,
294 Value src = adaptor.getSource();
299 if (srcType != dstType)
301 diag <<
"types doesn't match: " << srcType <<
" and " << dstType;
310 class ExtractAlignedPointerAsIndexOpPattern final
316 matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
327 AllocaOpPattern::matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
329 MemRefType allocType = allocaOp.getType();
334 Type spirvType = getTypeConverter()->convertType(allocType);
339 spirv::StorageClass::Function,
349 AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
351 MemRefType allocType = operation.getType();
356 Type spirvType = getTypeConverter()->convertType(allocType);
366 spirv::GlobalVariableOp varOp;
371 auto varOps = entryBlock.
getOps<spirv::GlobalVariableOp>();
372 std::string varName =
373 std::string(
"__workgroup_mem__") +
374 std::to_string(std::distance(varOps.begin(), varOps.end()));
375 varOp = rewriter.
create<spirv::GlobalVariableOp>(loc, spirvType, varName,
389 AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
392 if (isa<FloatType>(atomicOp.getType()))
394 "unimplemented floating-point case");
396 auto memrefType = cast<MemRefType>(atomicOp.getMemref().getType());
400 "unsupported memref memory space");
402 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
403 Type resultType = typeConverter.convertType(atomicOp.getType());
406 "failed to convert result type");
408 auto loc = atomicOp.getLoc();
411 adaptor.getIndices(), loc, rewriter);
416 #define ATOMIC_CASE(kind, spirvOp) \
417 case arith::AtomicRMWKind::kind: \
418 rewriter.replaceOpWithNewOp<spirv::spirvOp>( \
419 atomicOp, resultType, ptr, *scope, \
420 spirv::MemorySemantics::AcquireRelease, adaptor.getValue()); \
423 switch (atomicOp.getKind()) {
445 DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
448 MemRefType deallocType = cast<MemRefType>(operation.getMemref().getType());
466 static FailureOr<MemoryRequirements>
472 memoryAccess = spirv::MemoryAccess::Nontemporal;
475 auto ptrType = cast<spirv::PointerType>(accessedPtr.
getType());
476 if (ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer) {
485 auto pointeeType = dyn_cast<spirv::ScalarType>(ptrType.getPointeeType());
490 std::optional<int64_t> sizeInBytes = pointeeType.getSizeInBytes();
491 if (!sizeInBytes.has_value())
494 memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
503 template <
class LoadOrStoreOp>
504 static FailureOr<MemoryRequirements>
507 llvm::is_one_of<LoadOrStoreOp, memref::LoadOp, memref::StoreOp>::value,
508 "Must be called on either memref::LoadOp or memref::StoreOp");
510 Operation *memrefAccessOp = loadOrStoreOp.getOperation();
511 auto memrefMemAccess = memrefAccessOp->
getAttrOfType<spirv::MemoryAccessAttr>(
512 spirv::attributeName<spirv::MemoryAccess>());
513 auto memrefAlignment =
515 if (memrefMemAccess && memrefAlignment)
519 loadOrStoreOp.getNontemporal());
523 IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
525 auto loc = loadOp.getLoc();
526 auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
527 if (!memrefType.getElementType().isSignlessInteger())
530 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
533 adaptor.getIndices(), loc, rewriter);
538 int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
539 bool isBool = srcBits == 1;
541 srcBits = typeConverter.getOptions().boolNumBits;
547 Type pointeeType = pointerType.getPointeeType();
549 if (typeConverter.allows(spirv::Capability::Kernel)) {
550 if (
auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
551 dstType = arrayType.getElementType();
553 dstType = pointeeType;
556 Type structElemType =
557 cast<spirv::StructType>(pointeeType).getElementType(0);
558 if (
auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
559 dstType = arrayType.getElementType();
561 dstType = cast<spirv::RuntimeArrayType>(structElemType).getElementType();
564 assert(dstBits % srcBits == 0);
568 if (srcBits == dstBits) {
570 if (failed(memoryRequirements))
572 loadOp,
"failed to determine memory requirements");
574 auto [memoryAccess, alignment] = *memoryRequirements;
575 Value loadVal = rewriter.
create<spirv::LoadOp>(loc, accessChain,
576 memoryAccess, alignment);
585 if (typeConverter.allows(spirv::Capability::Kernel))
588 auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
595 assert(accessChainOp.getIndices().size() == 2);
597 srcBits, dstBits, rewriter);
599 if (failed(memoryRequirements))
601 loadOp,
"failed to determine memory requirements");
603 auto [memoryAccess, alignment] = *memoryRequirements;
604 Value spvLoadOp = rewriter.
create<spirv::LoadOp>(loc, dstType, adjustedPtr,
605 memoryAccess, alignment);
609 Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
612 loc, spvLoadOp.
getType(), spvLoadOp, offset);
616 loc, dstType, rewriter.
getIntegerAttr(dstType, (1 << srcBits) - 1));
618 rewriter.
createOrFold<spirv::BitwiseAndOp>(loc, dstType, result, mask);
623 IntegerAttr shiftValueAttr =
626 rewriter.
createOrFold<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
627 result = rewriter.
createOrFold<spirv::ShiftLeftLogicalOp>(loc, dstType,
629 result = rewriter.
createOrFold<spirv::ShiftRightArithmeticOp>(
634 assert(accessChainOp.use_empty());
635 rewriter.
eraseOp(accessChainOp);
641 LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
643 auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
644 if (memrefType.getElementType().isSignlessInteger())
647 *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
648 adaptor.getIndices(), loadOp.getLoc(), rewriter);
654 if (failed(memoryRequirements))
656 loadOp,
"failed to determine memory requirements");
658 auto [memoryAccess, alignment] = *memoryRequirements;
665 IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
667 auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
668 if (!memrefType.getElementType().isSignlessInteger())
670 "element type is not a signless int");
672 auto loc = storeOp.getLoc();
673 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
676 adaptor.getIndices(), loc, rewriter);
680 storeOp,
"failed to convert element pointer type");
682 int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
684 bool isBool = srcBits == 1;
686 srcBits = typeConverter.getOptions().boolNumBits;
691 "failed to convert memref type");
693 Type pointeeType = pointerType.getPointeeType();
695 if (typeConverter.allows(spirv::Capability::Kernel)) {
696 if (
auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
697 dstType = dyn_cast<IntegerType>(arrayType.getElementType());
699 dstType = dyn_cast<IntegerType>(pointeeType);
702 Type structElemType =
703 cast<spirv::StructType>(pointeeType).getElementType(0);
704 if (
auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
705 dstType = dyn_cast<IntegerType>(arrayType.getElementType());
707 dstType = dyn_cast<IntegerType>(
713 storeOp,
"failed to determine destination element type");
715 int dstBits =
static_cast<int>(dstType.getWidth());
716 assert(dstBits % srcBits == 0);
718 if (srcBits == dstBits) {
720 if (failed(memoryRequirements))
722 storeOp,
"failed to determine memory requirements");
724 auto [memoryAccess, alignment] = *memoryRequirements;
725 Value storeVal = adaptor.getValue();
729 memoryAccess, alignment);
735 if (typeConverter.allows(spirv::Capability::Kernel))
738 auto accessChainOp = accessChain.
getDefiningOp<spirv::AccessChainOp>();
753 assert(accessChainOp.getIndices().size() == 2);
754 Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
760 loc, dstType, rewriter.
getIntegerAttr(dstType, (1 << srcBits) - 1));
762 loc, dstType, mask, offset);
764 rewriter.
createOrFold<spirv::NotOp>(loc, dstType, clearBitsMask);
766 Value storeVal =
shiftValue(loc, adaptor.getValue(), offset, mask, rewriter);
768 srcBits, dstBits, rewriter);
773 Value result = rewriter.
create<spirv::AtomicAndOp>(
774 loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
776 result = rewriter.
create<spirv::AtomicOrOp>(
777 loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
786 assert(accessChainOp.use_empty());
787 rewriter.
eraseOp(accessChainOp);
796 LogicalResult MemorySpaceCastOpPattern::matchAndRewrite(
797 memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
800 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
801 if (!typeConverter.allows(spirv::Capability::Kernel))
803 loc,
"address space casts require kernel capability");
805 auto sourceType = dyn_cast<MemRefType>(addrCastOp.getSource().getType());
808 loc,
"SPIR-V lowering requires ranked memref types");
809 auto resultType = cast<MemRefType>(addrCastOp.getResult().getType());
811 auto sourceStorageClassAttr =
812 dyn_cast_or_null<spirv::StorageClassAttr>(sourceType.getMemorySpace());
813 if (!sourceStorageClassAttr)
815 diag <<
"source address space " << sourceType.getMemorySpace()
816 <<
" must be a SPIR-V storage class";
818 auto resultStorageClassAttr =
819 dyn_cast_or_null<spirv::StorageClassAttr>(resultType.getMemorySpace());
820 if (!resultStorageClassAttr)
822 diag <<
"result address space " << resultType.getMemorySpace()
823 <<
" must be a SPIR-V storage class";
826 spirv::StorageClass sourceSc = sourceStorageClassAttr.getValue();
827 spirv::StorageClass resultSc = resultStorageClassAttr.getValue();
829 Value result = adaptor.getSource();
830 Type resultPtrType = typeConverter.convertType(resultType);
833 "failed to convert memref type");
835 Type genericPtrType = resultPtrType;
843 if (sourceSc != spirv::StorageClass::Generic &&
844 resultSc != spirv::StorageClass::Generic) {
845 Type intermediateType =
847 sourceType.getLayout(),
848 rewriter.
getAttr<spirv::StorageClassAttr>(
849 spirv::StorageClass::Generic));
850 genericPtrType = typeConverter.convertType(intermediateType);
852 if (sourceSc != spirv::StorageClass::Generic) {
854 rewriter.
create<spirv::PtrCastToGenericOp>(loc, genericPtrType, result);
856 if (resultSc != spirv::StorageClass::Generic) {
858 rewriter.
create<spirv::GenericCastToPtrOp>(loc, resultPtrType, result);
865 StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
867 auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
868 if (memrefType.getElementType().isSignlessInteger())
871 *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
872 adaptor.getIndices(), storeOp.getLoc(), rewriter);
878 if (failed(memoryRequirements))
880 storeOp,
"failed to determine memory requirements");
882 auto [memoryAccess, alignment] = *memoryRequirements;
884 storeOp, storePtr, adaptor.getValue(), memoryAccess, alignment);
888 LogicalResult ReinterpretCastPattern::matchAndRewrite(
889 memref::ReinterpretCastOp op, OpAdaptor adaptor,
891 Value src = adaptor.getSource();
892 auto srcType = dyn_cast<spirv::PointerType>(src.
getType());
902 if (dstType != srcType)
904 diag <<
"invalid dst type " << op.getType();
908 getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter)
920 auto offsetValue = [&]() ->
Value {
921 if (
auto val = dyn_cast<Value>(offset))
924 int64_t attrVal = cast<IntegerAttr>(cast<Attribute>(offset)).getInt();
926 return rewriter.
createOrFold<spirv::ConstantOp>(loc, intType, attr);
938 LogicalResult ExtractAlignedPointerAsIndexOpPattern::matchAndRewrite(
939 memref::ExtractAlignedPointerAsIndexOp extractOp, OpAdaptor adaptor,
941 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
942 Type indexType = typeConverter.getIndexType();
944 adaptor.getSource());
956 .add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
957 DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern, LoadOpPattern,
958 MemorySpaceCastOpPattern, StoreOpPattern, ReinterpretCastPattern,
959 CastPattern, ExtractAlignedPointerAsIndexOpPattern>(
960 typeConverter,
patterns.getContext());
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder)
Casts the given srcInt into a boolean value.
static Value shiftValue(Location loc, Value value, Value offset, Value mask, OpBuilder &builder)
Returns the targetBits-bit value shifted by the given offset, and cast to the destination type,...
static Value adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter, spirv::AccessChainOp op, int sourceBits, int targetBits, OpBuilder &builder)
Returns an adjusted spirv::AccessChainOp.
static std::optional< spirv::Scope > getAtomicOpScope(MemRefType type)
Returns the scope to use for atomic operations used for emulating store operations of unsupported inte...
static bool isAllocationSupported(Operation *allocOp, MemRefType type)
Returns true if the allocations of memref type generated from allocOp can be lowered to SPIR-V.
static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits, int targetBits, OpBuilder &builder)
Returns the offset of the value in targetBits representation.
#define ATOMIC_CASE(kind, spirvOp)
static FailureOr< MemoryRequirements > calculateMemoryRequirements(Value accessedPtr, bool isNontemporal)
Given an accessed SPIR-V pointer, calculates its alignment requirements, if any.
static Value castBoolToIntN(Location loc, Value srcBool, Type dstType, OpBuilder &builder)
Casts the given srcBool into an integer of dstType.
static std::string diag(const llvm::Value &value)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerType getIntegerType(unsigned width)
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
AttrClass getAttrOfType(StringAttr name)
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Type conversion from builtin types to SPIR-V types for shader interface.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
void populateMemRefToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating MemRef ops to SPIR-V ops.
spirv::MemoryAccessAttr memoryAccess