25 #include "llvm/Support/Debug.h"
29 #define DEBUG_TYPE "memref-to-spirv-pattern"
50 assert(targetBits % sourceBits == 0);
52 IntegerAttr idxAttr = builder.
getIntegerAttr(type, targetBits / sourceBits);
53 auto idx = builder.
createOrFold<spirv::ConstantOp>(loc, type, idxAttr);
54 IntegerAttr srcBitsAttr = builder.
getIntegerAttr(type, sourceBits);
56 builder.
createOrFold<spirv::ConstantOp>(loc, type, srcBitsAttr);
57 auto m = builder.
createOrFold<spirv::UModOp>(loc, srcIdx, idx);
58 return builder.
createOrFold<spirv::IMulOp>(loc, type, m, srcBitsValue);
71 spirv::AccessChainOp op,
int sourceBits,
73 assert(targetBits % sourceBits == 0);
74 const auto loc = op.
getLoc();
77 IntegerAttr attr = builder.
getIntegerAttr(type, targetBits / sourceBits);
78 auto idx = builder.
createOrFold<spirv::ConstantOp>(loc, type, attr);
79 auto indices = llvm::to_vector<4>(op.getIndices());
81 assert(indices.size() == 2);
82 indices.back() = builder.
createOrFold<spirv::SDivOp>(loc, lastDim, idx);
84 return builder.
create<spirv::AccessChainOp>(loc, t, op.getBasePtr(), indices);
94 Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
95 return builder.
createOrFold<spirv::SelectOp>(loc, dstType, srcBool, one,
103 IntegerType dstType = cast<IntegerType>(mask.
getType());
104 int targetBits =
static_cast<int>(dstType.getWidth());
106 assert(valueBits <= targetBits);
108 if (valueBits == 1) {
111 if (valueBits < targetBits) {
112 value = builder.
create<spirv::UConvertOp>(
116 value = builder.
createOrFold<spirv::BitwiseAndOp>(loc, value, mask);
125 if (isa<memref::AllocOp, memref::DeallocOp>(allocOp)) {
126 auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
127 if (!sc || sc.getValue() != spirv::StorageClass::Workgroup)
129 }
else if (isa<memref::AllocaOp>(allocOp)) {
130 auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
131 if (!sc || sc.getValue() != spirv::StorageClass::Function)
139 if (!type.hasStaticShape())
142 Type elementType = type.getElementType();
143 if (
auto vecType = dyn_cast<VectorType>(elementType))
144 elementType = vecType.getElementType();
152 auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
153 switch (sc.getValue()) {
154 case spirv::StorageClass::StorageBuffer:
155 return spirv::Scope::Device;
156 case spirv::StorageClass::Workgroup:
157 return spirv::Scope::Workgroup;
169 auto one = spirv::ConstantOp::getOne(srcInt.
getType(), loc, builder);
170 return builder.
createOrFold<spirv::IEqualOp>(loc, srcInt, one);
189 matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
202 matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
207 class AtomicRMWOpPattern final
213 matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
224 matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor,
234 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
244 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
254 matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
259 class MemorySpaceCastOpPattern final
265 matchAndRewrite(memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
275 matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
279 class ReinterpretCastPattern final
285 matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
294 matchAndRewrite(memref::CastOp op, OpAdaptor adaptor,
296 Value src = adaptor.getSource();
301 if (srcType != dstType)
303 diag <<
"types doesn't match: " << srcType <<
" and " << dstType;
318 AllocaOpPattern::matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
320 MemRefType allocType = allocaOp.getType();
325 Type spirvType = getTypeConverter()->convertType(allocType);
330 spirv::StorageClass::Function,
340 AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
342 MemRefType allocType = operation.getType();
347 Type spirvType = getTypeConverter()->convertType(allocType);
357 spirv::GlobalVariableOp varOp;
362 auto varOps = entryBlock.
getOps<spirv::GlobalVariableOp>();
363 std::string varName =
364 std::string(
"__workgroup_mem__") +
365 std::to_string(std::distance(varOps.begin(), varOps.end()));
366 varOp = rewriter.
create<spirv::GlobalVariableOp>(loc, spirvType, varName,
380 AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
383 if (isa<FloatType>(atomicOp.getType()))
385 "unimplemented floating-point case");
387 auto memrefType = cast<MemRefType>(atomicOp.getMemref().getType());
391 "unsupported memref memory space");
393 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
394 Type resultType = typeConverter.convertType(atomicOp.getType());
397 "failed to convert result type");
399 auto loc = atomicOp.getLoc();
402 adaptor.getIndices(), loc, rewriter);
407 #define ATOMIC_CASE(kind, spirvOp) \
408 case arith::AtomicRMWKind::kind: \
409 rewriter.replaceOpWithNewOp<spirv::spirvOp>( \
410 atomicOp, resultType, ptr, *scope, \
411 spirv::MemorySemantics::AcquireRelease, adaptor.getValue()); \
414 switch (atomicOp.getKind()) {
436 DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
439 MemRefType deallocType = cast<MemRefType>(operation.getMemref().getType());
463 memoryAccess = spirv::MemoryAccess::Nontemporal;
466 auto ptrType = cast<spirv::PointerType>(accessedPtr.
getType());
467 if (ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer) {
476 auto pointeeType = dyn_cast<spirv::ScalarType>(ptrType.getPointeeType());
481 std::optional<int64_t> sizeInBytes = pointeeType.getSizeInBytes();
482 if (!sizeInBytes.has_value())
485 memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
494 template <
class LoadOrStoreOp>
498 llvm::is_one_of<LoadOrStoreOp, memref::LoadOp, memref::StoreOp>::value,
499 "Must be called on either memref::LoadOp or memref::StoreOp");
501 Operation *memrefAccessOp = loadOrStoreOp.getOperation();
502 auto memrefMemAccess = memrefAccessOp->
getAttrOfType<spirv::MemoryAccessAttr>(
503 spirv::attributeName<spirv::MemoryAccess>());
504 auto memrefAlignment =
506 if (memrefMemAccess && memrefAlignment)
510 loadOrStoreOp.getNontemporal());
514 IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
516 auto loc = loadOp.getLoc();
517 auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
518 if (!memrefType.getElementType().isSignlessInteger())
521 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
524 adaptor.getIndices(), loc, rewriter);
529 int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
530 bool isBool = srcBits == 1;
532 srcBits = typeConverter.getOptions().boolNumBits;
538 Type pointeeType = pointerType.getPointeeType();
540 if (typeConverter.allows(spirv::Capability::Kernel)) {
541 if (
auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
542 dstType = arrayType.getElementType();
544 dstType = pointeeType;
547 Type structElemType =
548 cast<spirv::StructType>(pointeeType).getElementType(0);
549 if (
auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
550 dstType = arrayType.getElementType();
552 dstType = cast<spirv::RuntimeArrayType>(structElemType).getElementType();
555 assert(dstBits % srcBits == 0);
559 if (srcBits == dstBits) {
561 if (
failed(memoryRequirements))
563 loadOp,
"failed to determine memory requirements");
565 auto [memoryAccess, alignment] = *memoryRequirements;
566 Value loadVal = rewriter.
create<spirv::LoadOp>(loc, accessChain,
567 memoryAccess, alignment);
576 if (typeConverter.allows(spirv::Capability::Kernel))
579 auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
586 assert(accessChainOp.getIndices().size() == 2);
588 srcBits, dstBits, rewriter);
590 if (
failed(memoryRequirements))
592 loadOp,
"failed to determine memory requirements");
594 auto [memoryAccess, alignment] = *memoryRequirements;
595 Value spvLoadOp = rewriter.
create<spirv::LoadOp>(loc, dstType, adjustedPtr,
596 memoryAccess, alignment);
600 Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
603 loc, spvLoadOp.
getType(), spvLoadOp, offset);
607 loc, dstType, rewriter.
getIntegerAttr(dstType, (1 << srcBits) - 1));
609 rewriter.
createOrFold<spirv::BitwiseAndOp>(loc, dstType, result, mask);
614 IntegerAttr shiftValueAttr =
617 rewriter.
createOrFold<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
618 result = rewriter.
createOrFold<spirv::ShiftLeftLogicalOp>(loc, dstType,
620 result = rewriter.
createOrFold<spirv::ShiftRightArithmeticOp>(
625 assert(accessChainOp.use_empty());
626 rewriter.
eraseOp(accessChainOp);
632 LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
634 auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
635 if (memrefType.getElementType().isSignlessInteger())
638 *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
639 adaptor.getIndices(), loadOp.getLoc(), rewriter);
645 if (
failed(memoryRequirements))
647 loadOp,
"failed to determine memory requirements");
649 auto [memoryAccess, alignment] = *memoryRequirements;
656 IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
658 auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
659 if (!memrefType.getElementType().isSignlessInteger())
661 "element type is not a signless int");
663 auto loc = storeOp.getLoc();
664 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
667 adaptor.getIndices(), loc, rewriter);
671 storeOp,
"failed to convert element pointer type");
673 int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
675 bool isBool = srcBits == 1;
677 srcBits = typeConverter.getOptions().boolNumBits;
682 "failed to convert memref type");
684 Type pointeeType = pointerType.getPointeeType();
686 if (typeConverter.allows(spirv::Capability::Kernel)) {
687 if (
auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
688 dstType = dyn_cast<IntegerType>(arrayType.getElementType());
690 dstType = dyn_cast<IntegerType>(pointeeType);
693 Type structElemType =
694 cast<spirv::StructType>(pointeeType).getElementType(0);
695 if (
auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
696 dstType = dyn_cast<IntegerType>(arrayType.getElementType());
698 dstType = dyn_cast<IntegerType>(
704 storeOp,
"failed to determine destination element type");
706 int dstBits =
static_cast<int>(dstType.getWidth());
707 assert(dstBits % srcBits == 0);
709 if (srcBits == dstBits) {
711 if (
failed(memoryRequirements))
713 storeOp,
"failed to determine memory requirements");
715 auto [memoryAccess, alignment] = *memoryRequirements;
716 Value storeVal = adaptor.getValue();
720 memoryAccess, alignment);
726 if (typeConverter.allows(spirv::Capability::Kernel))
729 auto accessChainOp = accessChain.
getDefiningOp<spirv::AccessChainOp>();
744 assert(accessChainOp.getIndices().size() == 2);
745 Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
751 loc, dstType, rewriter.
getIntegerAttr(dstType, (1 << srcBits) - 1));
753 loc, dstType, mask, offset);
755 rewriter.
createOrFold<spirv::NotOp>(loc, dstType, clearBitsMask);
757 Value storeVal =
shiftValue(loc, adaptor.getValue(), offset, mask, rewriter);
759 srcBits, dstBits, rewriter);
764 Value result = rewriter.
create<spirv::AtomicAndOp>(
765 loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
767 result = rewriter.
create<spirv::AtomicOrOp>(
768 loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
777 assert(accessChainOp.use_empty());
778 rewriter.
eraseOp(accessChainOp);
788 memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
791 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
792 if (!typeConverter.allows(spirv::Capability::Kernel))
794 loc,
"address space casts require kernel capability");
796 auto sourceType = dyn_cast<MemRefType>(addrCastOp.getSource().getType());
799 loc,
"SPIR-V lowering requires ranked memref types");
800 auto resultType = cast<MemRefType>(addrCastOp.getResult().getType());
802 auto sourceStorageClassAttr =
803 dyn_cast_or_null<spirv::StorageClassAttr>(sourceType.getMemorySpace());
804 if (!sourceStorageClassAttr)
806 diag <<
"source address space " << sourceType.getMemorySpace()
807 <<
" must be a SPIR-V storage class";
809 auto resultStorageClassAttr =
810 dyn_cast_or_null<spirv::StorageClassAttr>(resultType.getMemorySpace());
811 if (!resultStorageClassAttr)
813 diag <<
"result address space " << resultType.getMemorySpace()
814 <<
" must be a SPIR-V storage class";
817 spirv::StorageClass sourceSc = sourceStorageClassAttr.getValue();
818 spirv::StorageClass resultSc = resultStorageClassAttr.getValue();
820 Value result = adaptor.getSource();
821 Type resultPtrType = typeConverter.convertType(resultType);
824 "failed to convert memref type");
826 Type genericPtrType = resultPtrType;
834 if (sourceSc != spirv::StorageClass::Generic &&
835 resultSc != spirv::StorageClass::Generic) {
836 Type intermediateType =
838 sourceType.getLayout(),
839 rewriter.
getAttr<spirv::StorageClassAttr>(
840 spirv::StorageClass::Generic));
841 genericPtrType = typeConverter.convertType(intermediateType);
843 if (sourceSc != spirv::StorageClass::Generic) {
845 rewriter.
create<spirv::PtrCastToGenericOp>(loc, genericPtrType, result);
847 if (resultSc != spirv::StorageClass::Generic) {
849 rewriter.
create<spirv::GenericCastToPtrOp>(loc, resultPtrType, result);
856 StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
858 auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
859 if (memrefType.getElementType().isSignlessInteger())
862 *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
863 adaptor.getIndices(), storeOp.getLoc(), rewriter);
869 if (
failed(memoryRequirements))
871 storeOp,
"failed to determine memory requirements");
873 auto [memoryAccess, alignment] = *memoryRequirements;
875 storeOp, storePtr, adaptor.getValue(), memoryAccess, alignment);
880 memref::ReinterpretCastOp op, OpAdaptor adaptor,
882 Value src = adaptor.getSource();
883 auto srcType = dyn_cast<spirv::PointerType>(src.
getType());
893 if (dstType != srcType)
895 diag <<
"invalid dst type " << op.getType();
899 getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter)
911 auto offsetValue = [&]() ->
Value {
912 if (
auto val = dyn_cast<Value>(offset))
915 int64_t attrVal = cast<IntegerAttr>(offset.get<
Attribute>()).getInt();
917 return rewriter.
createOrFold<spirv::ConstantOp>(loc, intType, attr);
921 op, src, offsetValue, std::nullopt);
932 patterns.
add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
933 DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern,
934 LoadOpPattern, MemorySpaceCastOpPattern, StoreOpPattern,
935 ReinterpretCastPattern, CastPattern>(typeConverter,
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder)
Casts the given srcInt into a boolean value.
static Value shiftValue(Location loc, Value value, Value offset, Value mask, OpBuilder &builder)
Returns the targetBits-bit value shifted by the given offset, and cast to the type destination type,...
static Value adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter, spirv::AccessChainOp op, int sourceBits, int targetBits, OpBuilder &builder)
Returns an adjusted spirv::AccessChainOp.
static std::optional< spirv::Scope > getAtomicOpScope(MemRefType type)
Returns the scope to use for atomic operations use for emulating store operations of unsupported inte...
static bool isAllocationSupported(Operation *allocOp, MemRefType type)
Returns true if the allocations of memref type generated from allocOp can be lowered to SPIR-V.
static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits, int targetBits, OpBuilder &builder)
Returns the offset of the value in targetBits representation.
#define ATOMIC_CASE(kind, spirvOp)
static FailureOr< MemoryRequirements > calculateMemoryRequirements(Value accessedPtr, bool isNontemporal)
Given an accessed SPIR-V pointer, calculates its alignment requirements, if any.
static Value castBoolToIntN(Location loc, Value srcBool, Type dstType, OpBuilder &builder)
Casts the given srcBool into an integer of dstType.
static std::string diag(const llvm::Value &value)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerType getIntegerType(unsigned width)
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
This class provides support for representing a failure result, or a valid value of type T.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
AttrClass getAttrOfType(StringAttr name)
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Type conversion from builtin types to SPIR-V types for shader interface.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Include the generated interface declarations.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating MemRef ops to SPIR-V ops.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, Builder &b)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
spirv::MemoryAccessAttr memoryAccess
This class represents an efficient way to signal success or failure.