24 #include "llvm/Support/Debug.h"
28 #define DEBUG_TYPE "memref-to-spirv-pattern"
49 assert(targetBits % sourceBits == 0);
51 IntegerAttr idxAttr = builder.
getIntegerAttr(type, targetBits / sourceBits);
52 auto idx = builder.
createOrFold<spirv::ConstantOp>(loc, type, idxAttr);
53 IntegerAttr srcBitsAttr = builder.
getIntegerAttr(type, sourceBits);
55 builder.
createOrFold<spirv::ConstantOp>(loc, type, srcBitsAttr);
56 auto m = builder.
createOrFold<spirv::UModOp>(loc, srcIdx, idx);
57 return builder.
createOrFold<spirv::IMulOp>(loc, type, m, srcBitsValue);
70 spirv::AccessChainOp op,
int sourceBits,
72 assert(targetBits % sourceBits == 0);
73 const auto loc = op.
getLoc();
76 IntegerAttr attr = builder.
getIntegerAttr(type, targetBits / sourceBits);
77 auto idx = builder.
createOrFold<spirv::ConstantOp>(loc, type, attr);
78 auto indices = llvm::to_vector<4>(op.getIndices());
80 assert(indices.size() == 2);
81 indices.back() = builder.
createOrFold<spirv::SDivOp>(loc, lastDim, idx);
83 return builder.
create<spirv::AccessChainOp>(loc, t, op.getBasePtr(), indices);
93 Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
94 return builder.
createOrFold<spirv::SelectOp>(loc, dstType, srcBool, one,
102 IntegerType dstType = cast<IntegerType>(mask.
getType());
103 int targetBits =
static_cast<int>(dstType.getWidth());
105 assert(valueBits <= targetBits);
107 if (valueBits == 1) {
110 if (valueBits < targetBits) {
111 value = builder.
create<spirv::UConvertOp>(
115 value = builder.
createOrFold<spirv::BitwiseAndOp>(loc, value, mask);
124 if (isa<memref::AllocOp, memref::DeallocOp>(allocOp)) {
125 auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
126 if (!sc || sc.getValue() != spirv::StorageClass::Workgroup)
128 }
else if (isa<memref::AllocaOp>(allocOp)) {
129 auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
130 if (!sc || sc.getValue() != spirv::StorageClass::Function)
138 if (!type.hasStaticShape())
141 Type elementType = type.getElementType();
142 if (
auto vecType = dyn_cast<VectorType>(elementType))
143 elementType = vecType.getElementType();
151 auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
152 switch (sc.getValue()) {
153 case spirv::StorageClass::StorageBuffer:
154 return spirv::Scope::Device;
155 case spirv::StorageClass::Workgroup:
156 return spirv::Scope::Workgroup;
168 auto one = spirv::ConstantOp::getOne(srcInt.
getType(), loc, builder);
169 return builder.
createOrFold<spirv::IEqualOp>(loc, srcInt, one);
188 matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
201 matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
206 class AtomicRMWOpPattern final
212 matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
223 matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor,
233 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
243 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
253 matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
258 class MemorySpaceCastOpPattern final
264 matchAndRewrite(memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
274 matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
278 class ReinterpretCastPattern final
284 matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
293 matchAndRewrite(memref::CastOp op, OpAdaptor adaptor,
295 Value src = adaptor.getSource();
300 if (srcType != dstType)
302 diag <<
"types doesn't match: " << srcType <<
" and " << dstType;
317 AllocaOpPattern::matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
319 MemRefType allocType = allocaOp.getType();
324 Type spirvType = getTypeConverter()->convertType(allocType);
329 spirv::StorageClass::Function,
339 AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
341 MemRefType allocType = operation.getType();
346 Type spirvType = getTypeConverter()->convertType(allocType);
356 spirv::GlobalVariableOp varOp;
361 auto varOps = entryBlock.
getOps<spirv::GlobalVariableOp>();
362 std::string varName =
363 std::string(
"__workgroup_mem__") +
364 std::to_string(std::distance(varOps.begin(), varOps.end()));
365 varOp = rewriter.
create<spirv::GlobalVariableOp>(loc, spirvType, varName,
379 AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
382 if (isa<FloatType>(atomicOp.getType()))
384 "unimplemented floating-point case");
386 auto memrefType = cast<MemRefType>(atomicOp.getMemref().getType());
390 "unsupported memref memory space");
392 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
393 Type resultType = typeConverter.convertType(atomicOp.getType());
396 "failed to convert result type");
398 auto loc = atomicOp.getLoc();
401 adaptor.getIndices(), loc, rewriter);
406 #define ATOMIC_CASE(kind, spirvOp) \
407 case arith::AtomicRMWKind::kind: \
408 rewriter.replaceOpWithNewOp<spirv::spirvOp>( \
409 atomicOp, resultType, ptr, *scope, \
410 spirv::MemorySemantics::AcquireRelease, adaptor.getValue()); \
413 switch (atomicOp.getKind()) {
435 DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
438 MemRefType deallocType = cast<MemRefType>(operation.getMemref().getType());
456 static FailureOr<MemoryRequirements>
462 memoryAccess = spirv::MemoryAccess::Nontemporal;
465 auto ptrType = cast<spirv::PointerType>(accessedPtr.
getType());
466 if (ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer) {
475 auto pointeeType = dyn_cast<spirv::ScalarType>(ptrType.getPointeeType());
480 std::optional<int64_t> sizeInBytes = pointeeType.getSizeInBytes();
481 if (!sizeInBytes.has_value())
484 memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
493 template <
class LoadOrStoreOp>
494 static FailureOr<MemoryRequirements>
497 llvm::is_one_of<LoadOrStoreOp, memref::LoadOp, memref::StoreOp>::value,
498 "Must be called on either memref::LoadOp or memref::StoreOp");
500 Operation *memrefAccessOp = loadOrStoreOp.getOperation();
501 auto memrefMemAccess = memrefAccessOp->
getAttrOfType<spirv::MemoryAccessAttr>(
502 spirv::attributeName<spirv::MemoryAccess>());
503 auto memrefAlignment =
505 if (memrefMemAccess && memrefAlignment)
509 loadOrStoreOp.getNontemporal());
513 IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
515 auto loc = loadOp.getLoc();
516 auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
517 if (!memrefType.getElementType().isSignlessInteger())
520 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
523 adaptor.getIndices(), loc, rewriter);
528 int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
529 bool isBool = srcBits == 1;
531 srcBits = typeConverter.getOptions().boolNumBits;
537 Type pointeeType = pointerType.getPointeeType();
539 if (typeConverter.allows(spirv::Capability::Kernel)) {
540 if (
auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
541 dstType = arrayType.getElementType();
543 dstType = pointeeType;
546 Type structElemType =
547 cast<spirv::StructType>(pointeeType).getElementType(0);
548 if (
auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
549 dstType = arrayType.getElementType();
551 dstType = cast<spirv::RuntimeArrayType>(structElemType).getElementType();
554 assert(dstBits % srcBits == 0);
558 if (srcBits == dstBits) {
560 if (failed(memoryRequirements))
562 loadOp,
"failed to determine memory requirements");
564 auto [memoryAccess, alignment] = *memoryRequirements;
565 Value loadVal = rewriter.
create<spirv::LoadOp>(loc, accessChain,
566 memoryAccess, alignment);
575 if (typeConverter.allows(spirv::Capability::Kernel))
578 auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
585 assert(accessChainOp.getIndices().size() == 2);
587 srcBits, dstBits, rewriter);
589 if (failed(memoryRequirements))
591 loadOp,
"failed to determine memory requirements");
593 auto [memoryAccess, alignment] = *memoryRequirements;
594 Value spvLoadOp = rewriter.
create<spirv::LoadOp>(loc, dstType, adjustedPtr,
595 memoryAccess, alignment);
599 Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
602 loc, spvLoadOp.
getType(), spvLoadOp, offset);
606 loc, dstType, rewriter.
getIntegerAttr(dstType, (1 << srcBits) - 1));
608 rewriter.
createOrFold<spirv::BitwiseAndOp>(loc, dstType, result, mask);
613 IntegerAttr shiftValueAttr =
616 rewriter.
createOrFold<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
617 result = rewriter.
createOrFold<spirv::ShiftLeftLogicalOp>(loc, dstType,
619 result = rewriter.
createOrFold<spirv::ShiftRightArithmeticOp>(
624 assert(accessChainOp.use_empty());
625 rewriter.
eraseOp(accessChainOp);
631 LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
633 auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
634 if (memrefType.getElementType().isSignlessInteger())
637 *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
638 adaptor.getIndices(), loadOp.getLoc(), rewriter);
644 if (failed(memoryRequirements))
646 loadOp,
"failed to determine memory requirements");
648 auto [memoryAccess, alignment] = *memoryRequirements;
655 IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
657 auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
658 if (!memrefType.getElementType().isSignlessInteger())
660 "element type is not a signless int");
662 auto loc = storeOp.getLoc();
663 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
666 adaptor.getIndices(), loc, rewriter);
670 storeOp,
"failed to convert element pointer type");
672 int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
674 bool isBool = srcBits == 1;
676 srcBits = typeConverter.getOptions().boolNumBits;
681 "failed to convert memref type");
683 Type pointeeType = pointerType.getPointeeType();
685 if (typeConverter.allows(spirv::Capability::Kernel)) {
686 if (
auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
687 dstType = dyn_cast<IntegerType>(arrayType.getElementType());
689 dstType = dyn_cast<IntegerType>(pointeeType);
692 Type structElemType =
693 cast<spirv::StructType>(pointeeType).getElementType(0);
694 if (
auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
695 dstType = dyn_cast<IntegerType>(arrayType.getElementType());
697 dstType = dyn_cast<IntegerType>(
703 storeOp,
"failed to determine destination element type");
705 int dstBits =
static_cast<int>(dstType.getWidth());
706 assert(dstBits % srcBits == 0);
708 if (srcBits == dstBits) {
710 if (failed(memoryRequirements))
712 storeOp,
"failed to determine memory requirements");
714 auto [memoryAccess, alignment] = *memoryRequirements;
715 Value storeVal = adaptor.getValue();
719 memoryAccess, alignment);
725 if (typeConverter.allows(spirv::Capability::Kernel))
728 auto accessChainOp = accessChain.
getDefiningOp<spirv::AccessChainOp>();
743 assert(accessChainOp.getIndices().size() == 2);
744 Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
750 loc, dstType, rewriter.
getIntegerAttr(dstType, (1 << srcBits) - 1));
752 loc, dstType, mask, offset);
754 rewriter.
createOrFold<spirv::NotOp>(loc, dstType, clearBitsMask);
756 Value storeVal =
shiftValue(loc, adaptor.getValue(), offset, mask, rewriter);
758 srcBits, dstBits, rewriter);
763 Value result = rewriter.
create<spirv::AtomicAndOp>(
764 loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
766 result = rewriter.
create<spirv::AtomicOrOp>(
767 loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
776 assert(accessChainOp.use_empty());
777 rewriter.
eraseOp(accessChainOp);
786 LogicalResult MemorySpaceCastOpPattern::matchAndRewrite(
787 memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
790 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
791 if (!typeConverter.allows(spirv::Capability::Kernel))
793 loc,
"address space casts require kernel capability");
795 auto sourceType = dyn_cast<MemRefType>(addrCastOp.getSource().getType());
798 loc,
"SPIR-V lowering requires ranked memref types");
799 auto resultType = cast<MemRefType>(addrCastOp.getResult().getType());
801 auto sourceStorageClassAttr =
802 dyn_cast_or_null<spirv::StorageClassAttr>(sourceType.getMemorySpace());
803 if (!sourceStorageClassAttr)
805 diag <<
"source address space " << sourceType.getMemorySpace()
806 <<
" must be a SPIR-V storage class";
808 auto resultStorageClassAttr =
809 dyn_cast_or_null<spirv::StorageClassAttr>(resultType.getMemorySpace());
810 if (!resultStorageClassAttr)
812 diag <<
"result address space " << resultType.getMemorySpace()
813 <<
" must be a SPIR-V storage class";
816 spirv::StorageClass sourceSc = sourceStorageClassAttr.getValue();
817 spirv::StorageClass resultSc = resultStorageClassAttr.getValue();
819 Value result = adaptor.getSource();
820 Type resultPtrType = typeConverter.convertType(resultType);
823 "failed to convert memref type");
825 Type genericPtrType = resultPtrType;
833 if (sourceSc != spirv::StorageClass::Generic &&
834 resultSc != spirv::StorageClass::Generic) {
835 Type intermediateType =
837 sourceType.getLayout(),
838 rewriter.
getAttr<spirv::StorageClassAttr>(
839 spirv::StorageClass::Generic));
840 genericPtrType = typeConverter.convertType(intermediateType);
842 if (sourceSc != spirv::StorageClass::Generic) {
844 rewriter.
create<spirv::PtrCastToGenericOp>(loc, genericPtrType, result);
846 if (resultSc != spirv::StorageClass::Generic) {
848 rewriter.
create<spirv::GenericCastToPtrOp>(loc, resultPtrType, result);
855 StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
857 auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
858 if (memrefType.getElementType().isSignlessInteger())
861 *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
862 adaptor.getIndices(), storeOp.getLoc(), rewriter);
868 if (failed(memoryRequirements))
870 storeOp,
"failed to determine memory requirements");
872 auto [memoryAccess, alignment] = *memoryRequirements;
874 storeOp, storePtr, adaptor.getValue(), memoryAccess, alignment);
878 LogicalResult ReinterpretCastPattern::matchAndRewrite(
879 memref::ReinterpretCastOp op, OpAdaptor adaptor,
881 Value src = adaptor.getSource();
882 auto srcType = dyn_cast<spirv::PointerType>(src.
getType());
892 if (dstType != srcType)
894 diag <<
"invalid dst type " << op.getType();
898 getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter)
910 auto offsetValue = [&]() ->
Value {
911 if (
auto val = dyn_cast<Value>(offset))
914 int64_t attrVal = cast<IntegerAttr>(offset.get<
Attribute>()).getInt();
916 return rewriter.
createOrFold<spirv::ConstantOp>(loc, intType, attr);
920 op, src, offsetValue, std::nullopt);
931 patterns.
add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
932 DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern,
933 LoadOpPattern, MemorySpaceCastOpPattern, StoreOpPattern,
934 ReinterpretCastPattern, CastPattern>(typeConverter,
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder)
Casts the given srcInt into a boolean value.
static Value shiftValue(Location loc, Value value, Value offset, Value mask, OpBuilder &builder)
Returns the targetBits-bit value shifted by the given offset, and cast to the type destination type,...
static Value adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter, spirv::AccessChainOp op, int sourceBits, int targetBits, OpBuilder &builder)
Returns an adjusted spirv::AccessChainOp.
static std::optional< spirv::Scope > getAtomicOpScope(MemRefType type)
Returns the scope to use for atomic operations use for emulating store operations of unsupported inte...
static bool isAllocationSupported(Operation *allocOp, MemRefType type)
Returns true if the allocations of memref type generated from allocOp can be lowered to SPIR-V.
static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits, int targetBits, OpBuilder &builder)
Returns the offset of the value in targetBits representation.
#define ATOMIC_CASE(kind, spirvOp)
static FailureOr< MemoryRequirements > calculateMemoryRequirements(Value accessedPtr, bool isNontemporal)
Given an accessed SPIR-V pointer, calculates its alignment requirements, if any.
static Value castBoolToIntN(Location loc, Value srcBool, Type dstType, OpBuilder &builder)
Casts the given srcBool into an integer of dstType.
static std::string diag(const llvm::Value &value)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerType getIntegerType(unsigned width)
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
AttrClass getAttrOfType(StringAttr name)
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Type conversion from builtin types to SPIR-V types for shader interface.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Include the generated interface declarations.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, Builder &b)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
void populateMemRefToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating MemRef ops to SPIR-V ops.
spirv::MemoryAccessAttr memoryAccess