24 #include "llvm/Support/Debug.h"
28 #define DEBUG_TYPE "memref-to-spirv-pattern"
49 assert(targetBits % sourceBits == 0);
51 IntegerAttr idxAttr = builder.
getIntegerAttr(type, targetBits / sourceBits);
52 auto idx = builder.
createOrFold<spirv::ConstantOp>(loc, type, idxAttr);
53 IntegerAttr srcBitsAttr = builder.
getIntegerAttr(type, sourceBits);
55 builder.
createOrFold<spirv::ConstantOp>(loc, type, srcBitsAttr);
56 auto m = builder.
createOrFold<spirv::UModOp>(loc, srcIdx, idx);
57 return builder.
createOrFold<spirv::IMulOp>(loc, type, m, srcBitsValue);
70 spirv::AccessChainOp op,
int sourceBits,
72 assert(targetBits % sourceBits == 0);
73 const auto loc = op.getLoc();
74 Value lastDim = op->getOperand(op.getNumOperands() - 1);
76 IntegerAttr attr = builder.
getIntegerAttr(type, targetBits / sourceBits);
77 auto idx = builder.
createOrFold<spirv::ConstantOp>(loc, type, attr);
78 auto indices = llvm::to_vector<4>(op.getIndices());
80 assert(indices.size() == 2);
81 indices.back() = builder.
createOrFold<spirv::SDivOp>(loc, lastDim, idx);
83 return builder.
create<spirv::AccessChainOp>(loc, t, op.getBasePtr(), indices);
93 Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
94 return builder.
createOrFold<spirv::SelectOp>(loc, dstType, srcBool, one,
102 IntegerType dstType = cast<IntegerType>(mask.
getType());
103 int targetBits =
static_cast<int>(dstType.getWidth());
105 assert(valueBits <= targetBits);
107 if (valueBits == 1) {
110 if (valueBits < targetBits) {
111 value = builder.
create<spirv::UConvertOp>(
115 value = builder.
createOrFold<spirv::BitwiseAndOp>(loc, value, mask);
124 if (isa<memref::AllocOp, memref::DeallocOp>(allocOp)) {
125 auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
126 if (!sc || sc.getValue() != spirv::StorageClass::Workgroup)
128 }
else if (isa<memref::AllocaOp>(allocOp)) {
129 auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
130 if (!sc || sc.getValue() != spirv::StorageClass::Function)
138 if (!type.hasStaticShape())
141 Type elementType = type.getElementType();
142 if (
auto vecType = dyn_cast<VectorType>(elementType))
143 elementType = vecType.getElementType();
151 auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
152 switch (sc.getValue()) {
153 case spirv::StorageClass::StorageBuffer:
154 return spirv::Scope::Device;
155 case spirv::StorageClass::Workgroup:
156 return spirv::Scope::Workgroup;
169 return builder.
createOrFold<spirv::INotEqualOp>(loc, srcInt, one);
188 matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
201 matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
206 class AtomicRMWOpPattern final
212 matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
223 matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor,
233 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
243 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
253 matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
258 class MemorySpaceCastOpPattern final
264 matchAndRewrite(memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
274 matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
278 class ReinterpretCastPattern final
284 matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
293 matchAndRewrite(memref::CastOp op, OpAdaptor adaptor,
295 Value src = adaptor.getSource();
300 if (srcType != dstType)
302 diag <<
"types doesn't match: " << srcType <<
" and " << dstType;
311 class ExtractAlignedPointerAsIndexOpPattern final
317 matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
328 AllocaOpPattern::matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
330 MemRefType allocType = allocaOp.getType();
335 Type spirvType = getTypeConverter()->convertType(allocType);
340 spirv::StorageClass::Function,
350 AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
352 MemRefType allocType = operation.getType();
357 Type spirvType = getTypeConverter()->convertType(allocType);
367 spirv::GlobalVariableOp varOp;
372 auto varOps = entryBlock.
getOps<spirv::GlobalVariableOp>();
373 std::string varName =
374 std::string(
"__workgroup_mem__") +
375 std::to_string(std::distance(varOps.begin(), varOps.end()));
376 varOp = rewriter.
create<spirv::GlobalVariableOp>(loc, spirvType, varName,
390 AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
393 if (isa<FloatType>(atomicOp.getType()))
395 "unimplemented floating-point case");
397 auto memrefType = cast<MemRefType>(atomicOp.getMemref().getType());
401 "unsupported memref memory space");
403 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
404 Type resultType = typeConverter.convertType(atomicOp.getType());
407 "failed to convert result type");
409 auto loc = atomicOp.getLoc();
412 adaptor.getIndices(), loc, rewriter);
417 #define ATOMIC_CASE(kind, spirvOp) \
418 case arith::AtomicRMWKind::kind: \
419 rewriter.replaceOpWithNewOp<spirv::spirvOp>( \
420 atomicOp, resultType, ptr, *scope, \
421 spirv::MemorySemantics::AcquireRelease, adaptor.getValue()); \
424 switch (atomicOp.getKind()) {
446 DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
449 MemRefType deallocType = cast<MemRefType>(operation.getMemref().getType());
467 static FailureOr<MemoryRequirements>
473 memoryAccess = spirv::MemoryAccess::Nontemporal;
476 auto ptrType = cast<spirv::PointerType>(accessedPtr.
getType());
477 if (ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer) {
486 auto pointeeType = dyn_cast<spirv::ScalarType>(ptrType.getPointeeType());
491 std::optional<int64_t> sizeInBytes = pointeeType.getSizeInBytes();
492 if (!sizeInBytes.has_value())
495 memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
504 template <
class LoadOrStoreOp>
505 static FailureOr<MemoryRequirements>
508 llvm::is_one_of<LoadOrStoreOp, memref::LoadOp, memref::StoreOp>::value,
509 "Must be called on either memref::LoadOp or memref::StoreOp");
511 Operation *memrefAccessOp = loadOrStoreOp.getOperation();
512 auto memrefMemAccess = memrefAccessOp->
getAttrOfType<spirv::MemoryAccessAttr>(
513 spirv::attributeName<spirv::MemoryAccess>());
514 auto memrefAlignment =
516 if (memrefMemAccess && memrefAlignment)
520 loadOrStoreOp.getNontemporal());
524 IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
526 auto loc = loadOp.getLoc();
527 auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
528 if (!memrefType.getElementType().isSignlessInteger())
531 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
534 adaptor.getIndices(), loc, rewriter);
539 int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
540 bool isBool = srcBits == 1;
542 srcBits = typeConverter.getOptions().boolNumBits;
548 Type pointeeType = pointerType.getPointeeType();
550 if (typeConverter.allows(spirv::Capability::Kernel)) {
551 if (
auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
552 dstType = arrayType.getElementType();
554 dstType = pointeeType;
557 Type structElemType =
558 cast<spirv::StructType>(pointeeType).getElementType(0);
559 if (
auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
560 dstType = arrayType.getElementType();
562 dstType = cast<spirv::RuntimeArrayType>(structElemType).getElementType();
565 assert(dstBits % srcBits == 0);
569 if (srcBits == dstBits) {
571 if (failed(memoryRequirements))
573 loadOp,
"failed to determine memory requirements");
575 auto [memoryAccess, alignment] = *memoryRequirements;
576 Value loadVal = rewriter.
create<spirv::LoadOp>(loc, accessChain,
577 memoryAccess, alignment);
586 if (typeConverter.allows(spirv::Capability::Kernel))
589 auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
596 assert(accessChainOp.getIndices().size() == 2);
598 srcBits, dstBits, rewriter);
600 if (failed(memoryRequirements))
602 loadOp,
"failed to determine memory requirements");
604 auto [memoryAccess, alignment] = *memoryRequirements;
605 Value spvLoadOp = rewriter.
create<spirv::LoadOp>(loc, dstType, adjustedPtr,
606 memoryAccess, alignment);
610 Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
613 loc, spvLoadOp.
getType(), spvLoadOp, offset);
617 loc, dstType, rewriter.
getIntegerAttr(dstType, (1 << srcBits) - 1));
619 rewriter.
createOrFold<spirv::BitwiseAndOp>(loc, dstType, result, mask);
624 IntegerAttr shiftValueAttr =
627 rewriter.
createOrFold<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
628 result = rewriter.
createOrFold<spirv::ShiftLeftLogicalOp>(loc, dstType,
630 result = rewriter.
createOrFold<spirv::ShiftRightArithmeticOp>(
635 assert(accessChainOp.use_empty());
636 rewriter.
eraseOp(accessChainOp);
642 LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
644 auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
645 if (memrefType.getElementType().isSignlessInteger())
648 *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
649 adaptor.getIndices(), loadOp.getLoc(), rewriter);
655 if (failed(memoryRequirements))
657 loadOp,
"failed to determine memory requirements");
659 auto [memoryAccess, alignment] = *memoryRequirements;
666 IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
668 auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
669 if (!memrefType.getElementType().isSignlessInteger())
671 "element type is not a signless int");
673 auto loc = storeOp.getLoc();
674 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
677 adaptor.getIndices(), loc, rewriter);
681 storeOp,
"failed to convert element pointer type");
683 int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
685 bool isBool = srcBits == 1;
687 srcBits = typeConverter.getOptions().boolNumBits;
692 "failed to convert memref type");
694 Type pointeeType = pointerType.getPointeeType();
696 if (typeConverter.allows(spirv::Capability::Kernel)) {
697 if (
auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
698 dstType = dyn_cast<IntegerType>(arrayType.getElementType());
700 dstType = dyn_cast<IntegerType>(pointeeType);
703 Type structElemType =
704 cast<spirv::StructType>(pointeeType).getElementType(0);
705 if (
auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
706 dstType = dyn_cast<IntegerType>(arrayType.getElementType());
708 dstType = dyn_cast<IntegerType>(
714 storeOp,
"failed to determine destination element type");
716 int dstBits =
static_cast<int>(dstType.getWidth());
717 assert(dstBits % srcBits == 0);
719 if (srcBits == dstBits) {
721 if (failed(memoryRequirements))
723 storeOp,
"failed to determine memory requirements");
725 auto [memoryAccess, alignment] = *memoryRequirements;
726 Value storeVal = adaptor.getValue();
730 memoryAccess, alignment);
736 if (typeConverter.allows(spirv::Capability::Kernel))
739 auto accessChainOp = accessChain.
getDefiningOp<spirv::AccessChainOp>();
754 assert(accessChainOp.getIndices().size() == 2);
755 Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
761 loc, dstType, rewriter.
getIntegerAttr(dstType, (1 << srcBits) - 1));
763 loc, dstType, mask, offset);
765 rewriter.
createOrFold<spirv::NotOp>(loc, dstType, clearBitsMask);
767 Value storeVal =
shiftValue(loc, adaptor.getValue(), offset, mask, rewriter);
769 srcBits, dstBits, rewriter);
774 Value result = rewriter.
create<spirv::AtomicAndOp>(
775 loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
777 result = rewriter.
create<spirv::AtomicOrOp>(
778 loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
787 assert(accessChainOp.use_empty());
788 rewriter.
eraseOp(accessChainOp);
797 LogicalResult MemorySpaceCastOpPattern::matchAndRewrite(
798 memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
801 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
802 if (!typeConverter.allows(spirv::Capability::Kernel))
804 loc,
"address space casts require kernel capability");
806 auto sourceType = dyn_cast<MemRefType>(addrCastOp.getSource().getType());
809 loc,
"SPIR-V lowering requires ranked memref types");
810 auto resultType = cast<MemRefType>(addrCastOp.getResult().getType());
812 auto sourceStorageClassAttr =
813 dyn_cast_or_null<spirv::StorageClassAttr>(sourceType.getMemorySpace());
814 if (!sourceStorageClassAttr)
816 diag <<
"source address space " << sourceType.getMemorySpace()
817 <<
" must be a SPIR-V storage class";
819 auto resultStorageClassAttr =
820 dyn_cast_or_null<spirv::StorageClassAttr>(resultType.getMemorySpace());
821 if (!resultStorageClassAttr)
823 diag <<
"result address space " << resultType.getMemorySpace()
824 <<
" must be a SPIR-V storage class";
827 spirv::StorageClass sourceSc = sourceStorageClassAttr.getValue();
828 spirv::StorageClass resultSc = resultStorageClassAttr.getValue();
830 Value result = adaptor.getSource();
831 Type resultPtrType = typeConverter.convertType(resultType);
834 "failed to convert memref type");
836 Type genericPtrType = resultPtrType;
844 if (sourceSc != spirv::StorageClass::Generic &&
845 resultSc != spirv::StorageClass::Generic) {
846 Type intermediateType =
848 sourceType.getLayout(),
849 rewriter.
getAttr<spirv::StorageClassAttr>(
850 spirv::StorageClass::Generic));
851 genericPtrType = typeConverter.convertType(intermediateType);
853 if (sourceSc != spirv::StorageClass::Generic) {
855 rewriter.
create<spirv::PtrCastToGenericOp>(loc, genericPtrType, result);
857 if (resultSc != spirv::StorageClass::Generic) {
859 rewriter.
create<spirv::GenericCastToPtrOp>(loc, resultPtrType, result);
866 StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
868 auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
869 if (memrefType.getElementType().isSignlessInteger())
872 *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
873 adaptor.getIndices(), storeOp.getLoc(), rewriter);
879 if (failed(memoryRequirements))
881 storeOp,
"failed to determine memory requirements");
883 auto [memoryAccess, alignment] = *memoryRequirements;
885 storeOp, storePtr, adaptor.getValue(), memoryAccess, alignment);
889 LogicalResult ReinterpretCastPattern::matchAndRewrite(
890 memref::ReinterpretCastOp op, OpAdaptor adaptor,
892 Value src = adaptor.getSource();
893 auto srcType = dyn_cast<spirv::PointerType>(src.
getType());
903 if (dstType != srcType)
905 diag <<
"invalid dst type " << op.getType();
909 getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter)
921 auto offsetValue = [&]() ->
Value {
922 if (
auto val = dyn_cast<Value>(offset))
925 int64_t attrVal = cast<IntegerAttr>(cast<Attribute>(offset)).getInt();
927 return rewriter.
createOrFold<spirv::ConstantOp>(loc, intType, attr);
931 op, src, offsetValue, std::nullopt);
939 LogicalResult ExtractAlignedPointerAsIndexOpPattern::matchAndRewrite(
940 memref::ExtractAlignedPointerAsIndexOp extractOp, OpAdaptor adaptor,
942 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
943 Type indexType = typeConverter.getIndexType();
945 adaptor.getSource());
957 .add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
958 DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern, LoadOpPattern,
959 MemorySpaceCastOpPattern, StoreOpPattern, ReinterpretCastPattern,
960 CastPattern, ExtractAlignedPointerAsIndexOpPattern>(
961 typeConverter,
patterns.getContext());
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder)
Casts the given srcInt into a boolean value.
static Value shiftValue(Location loc, Value value, Value offset, Value mask, OpBuilder &builder)
Returns the targetBits-bit value shifted by the given offset, and cast to the type destination type,...
static Value adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter, spirv::AccessChainOp op, int sourceBits, int targetBits, OpBuilder &builder)
Returns an adjusted spirv::AccessChainOp.
static std::optional< spirv::Scope > getAtomicOpScope(MemRefType type)
Returns the scope to use for atomic operations use for emulating store operations of unsupported inte...
static bool isAllocationSupported(Operation *allocOp, MemRefType type)
Returns true if the allocations of memref type generated from allocOp can be lowered to SPIR-V.
static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits, int targetBits, OpBuilder &builder)
Returns the offset of the value in targetBits representation.
#define ATOMIC_CASE(kind, spirvOp)
static FailureOr< MemoryRequirements > calculateMemoryRequirements(Value accessedPtr, bool isNontemporal)
Given an accessed SPIR-V pointer, calculates its alignment requirements, if any.
static Value castBoolToIntN(Location loc, Value srcBool, Type dstType, OpBuilder &builder)
Casts the given srcBool into an integer of dstType.
static std::string diag(const llvm::Value &value)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerType getIntegerType(unsigned width)
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
AttrClass getAttrOfType(StringAttr name)
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Type conversion from builtin types to SPIR-V types for shader interface.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
void populateMemRefToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating MemRef ops to SPIR-V ops.
spirv::MemoryAccessAttr memoryAccess