20 #include "llvm/Support/Debug.h"
23 #define DEBUG_TYPE "memref-to-spirv-pattern"
44 assert(targetBits % sourceBits == 0);
46 IntegerAttr idxAttr = builder.
getIntegerAttr(type, targetBits / sourceBits);
47 auto idx = builder.
create<spirv::ConstantOp>(loc, type, idxAttr);
48 IntegerAttr srcBitsAttr = builder.
getIntegerAttr(type, sourceBits);
49 auto srcBitsValue = builder.
create<spirv::ConstantOp>(loc, type, srcBitsAttr);
50 auto m = builder.
create<spirv::UModOp>(loc, srcIdx, idx);
51 return builder.
create<spirv::IMulOp>(loc, type, m, srcBitsValue);
64 spirv::AccessChainOp op,
int sourceBits,
66 assert(targetBits % sourceBits == 0);
67 const auto loc = op.
getLoc();
70 IntegerAttr attr = builder.
getIntegerAttr(type, targetBits / sourceBits);
71 auto idx = builder.
create<spirv::ConstantOp>(loc, type, attr);
72 auto indices = llvm::to_vector<4>(op.getIndices());
74 assert(indices.size() == 2);
75 indices.back() = builder.
create<spirv::SDivOp>(loc, lastDim, idx);
77 return builder.
create<spirv::AccessChainOp>(loc, t, op.getBasePtr(), indices);
87 Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
88 return builder.
create<spirv::SelectOp>(loc, dstType, srcBool, one, zero);
95 IntegerType dstType = cast<IntegerType>(mask.
getType());
96 int targetBits =
static_cast<int>(dstType.getWidth());
98 assert(valueBits <= targetBits);
100 if (valueBits == 1) {
103 if (valueBits < targetBits) {
104 value = builder.
create<spirv::UConvertOp>(
108 value = builder.
create<spirv::BitwiseAndOp>(loc, value, mask);
110 return builder.
create<spirv::ShiftLeftLogicalOp>(loc, value.
getType(), value,
117 if (isa<memref::AllocOp, memref::DeallocOp>(allocOp)) {
118 auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
119 if (!sc || sc.getValue() != spirv::StorageClass::Workgroup)
121 }
else if (isa<memref::AllocaOp>(allocOp)) {
122 auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
123 if (!sc || sc.getValue() != spirv::StorageClass::Function)
131 if (!type.hasStaticShape())
134 Type elementType = type.getElementType();
135 if (
auto vecType = dyn_cast<VectorType>(elementType))
136 elementType = vecType.getElementType();
144 auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
145 switch (sc.getValue()) {
146 case spirv::StorageClass::StorageBuffer:
147 return spirv::Scope::Device;
148 case spirv::StorageClass::Workgroup:
149 return spirv::Scope::Workgroup;
161 auto one = spirv::ConstantOp::getOne(srcInt.
getType(), loc, builder);
162 return builder.
create<spirv::IEqualOp>(loc, srcInt, one);
181 matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
194 matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
199 class AtomicRMWOpPattern final
205 matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
216 matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor,
226 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
236 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
246 matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
251 class MemorySpaceCastOpPattern final
257 matchAndRewrite(memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
267 matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
271 class ReinterpretCastPattern final
277 matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
286 matchAndRewrite(memref::CastOp op, OpAdaptor adaptor,
288 Value src = adaptor.getSource();
293 if (srcType != dstType)
295 diag <<
"types doesn't match: " << srcType <<
" and " << dstType;
310 AllocaOpPattern::matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
312 MemRefType allocType = allocaOp.getType();
317 Type spirvType = getTypeConverter()->convertType(allocType);
322 spirv::StorageClass::Function,
332 AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
334 MemRefType allocType = operation.getType();
339 Type spirvType = getTypeConverter()->convertType(allocType);
349 spirv::GlobalVariableOp varOp;
354 auto varOps = entryBlock.
getOps<spirv::GlobalVariableOp>();
355 std::string varName =
356 std::string(
"__workgroup_mem__") +
357 std::to_string(std::distance(varOps.begin(), varOps.end()));
358 varOp = rewriter.
create<spirv::GlobalVariableOp>(loc, spirvType, varName,
372 AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
375 if (isa<FloatType>(atomicOp.getType()))
377 "unimplemented floating-point case");
379 auto memrefType = cast<MemRefType>(atomicOp.getMemref().getType());
383 "unsupported memref memory space");
385 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
386 Type resultType = typeConverter.convertType(atomicOp.getType());
389 "failed to convert result type");
391 auto loc = atomicOp.getLoc();
394 adaptor.getIndices(), loc, rewriter);
399 #define ATOMIC_CASE(kind, spirvOp) \
400 case arith::AtomicRMWKind::kind: \
401 rewriter.replaceOpWithNewOp<spirv::spirvOp>( \
402 atomicOp, resultType, ptr, *scope, \
403 spirv::MemorySemantics::AcquireRelease, adaptor.getValue()); \
406 switch (atomicOp.getKind()) {
428 DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
431 MemRefType deallocType = cast<MemRefType>(operation.getMemref().getType());
443 IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
445 auto loc = loadOp.getLoc();
446 auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
447 if (!memrefType.getElementType().isSignlessInteger())
450 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
453 adaptor.getIndices(), loc, rewriter);
458 int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
459 bool isBool = srcBits == 1;
461 srcBits = typeConverter.getOptions().boolNumBits;
467 Type pointeeType = pointerType.getPointeeType();
469 if (typeConverter.allows(spirv::Capability::Kernel)) {
470 if (
auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
471 dstType = arrayType.getElementType();
473 dstType = pointeeType;
476 Type structElemType =
477 cast<spirv::StructType>(pointeeType).getElementType(0);
478 if (
auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
479 dstType = arrayType.getElementType();
481 dstType = cast<spirv::RuntimeArrayType>(structElemType).getElementType();
484 assert(dstBits % srcBits == 0);
488 if (srcBits == dstBits) {
489 Value loadVal = rewriter.
create<spirv::LoadOp>(loc, accessChain);
498 if (typeConverter.allows(spirv::Capability::Kernel))
501 auto accessChainOp = accessChain.
getDefiningOp<spirv::AccessChainOp>();
508 assert(accessChainOp.getIndices().size() == 2);
510 srcBits, dstBits, rewriter);
512 loc, dstType, adjustedPtr,
513 loadOp->getAttrOfType<spirv::MemoryAccessAttr>(
514 spirv::attributeName<spirv::MemoryAccess>()),
515 loadOp->getAttrOfType<IntegerAttr>(
"alignment"));
519 Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
521 Value result = rewriter.
create<spirv::ShiftRightArithmeticOp>(
522 loc, spvLoadOp.
getType(), spvLoadOp, offset);
526 loc, dstType, rewriter.
getIntegerAttr(dstType, (1 << srcBits) - 1));
527 result = rewriter.
create<spirv::BitwiseAndOp>(loc, dstType, result, mask);
532 IntegerAttr shiftValueAttr =
535 rewriter.
create<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
536 result = rewriter.
create<spirv::ShiftLeftLogicalOp>(loc, dstType, result,
538 result = rewriter.
create<spirv::ShiftRightArithmeticOp>(loc, dstType, result,
543 assert(accessChainOp.use_empty());
544 rewriter.
eraseOp(accessChainOp);
550 LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
552 auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
553 if (memrefType.getElementType().isSignlessInteger())
556 *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
557 adaptor.getIndices(), loadOp.getLoc(), rewriter);
567 IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
569 auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
570 if (!memrefType.getElementType().isSignlessInteger())
572 "element type is not a signless int");
574 auto loc = storeOp.getLoc();
575 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
578 adaptor.getIndices(), loc, rewriter);
582 storeOp,
"failed to convert element pointer type");
584 int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
586 bool isBool = srcBits == 1;
588 srcBits = typeConverter.getOptions().boolNumBits;
593 "failed to convert memref type");
595 Type pointeeType = pointerType.getPointeeType();
597 if (typeConverter.allows(spirv::Capability::Kernel)) {
598 if (
auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
599 dstType = dyn_cast<IntegerType>(arrayType.getElementType());
601 dstType = dyn_cast<IntegerType>(pointeeType);
604 Type structElemType =
605 cast<spirv::StructType>(pointeeType).getElementType(0);
606 if (
auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
607 dstType = dyn_cast<IntegerType>(arrayType.getElementType());
609 dstType = dyn_cast<IntegerType>(
615 storeOp,
"failed to determine destination element type");
617 int dstBits =
static_cast<int>(dstType.getWidth());
618 assert(dstBits % srcBits == 0);
620 if (srcBits == dstBits) {
621 Value storeVal = adaptor.getValue();
630 if (typeConverter.allows(spirv::Capability::Kernel))
633 auto accessChainOp = accessChain.
getDefiningOp<spirv::AccessChainOp>();
648 assert(accessChainOp.getIndices().size() == 2);
649 Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
655 loc, dstType, rewriter.
getIntegerAttr(dstType, (1 << srcBits) - 1));
656 Value clearBitsMask =
657 rewriter.
create<spirv::ShiftLeftLogicalOp>(loc, dstType, mask, offset);
658 clearBitsMask = rewriter.
create<spirv::NotOp>(loc, dstType, clearBitsMask);
660 Value storeVal =
shiftValue(loc, adaptor.getValue(), offset, mask, rewriter);
662 srcBits, dstBits, rewriter);
667 Value result = rewriter.
create<spirv::AtomicAndOp>(
668 loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
670 result = rewriter.
create<spirv::AtomicOrOp>(
671 loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
680 assert(accessChainOp.use_empty());
681 rewriter.
eraseOp(accessChainOp);
691 memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
694 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
695 if (!typeConverter.allows(spirv::Capability::Kernel))
697 loc,
"address space casts require kernel capability");
699 auto sourceType = dyn_cast<MemRefType>(addrCastOp.getSource().getType());
702 loc,
"SPIR-V lowering requires ranked memref types");
703 auto resultType = cast<MemRefType>(addrCastOp.getResult().getType());
705 auto sourceStorageClassAttr =
706 dyn_cast_or_null<spirv::StorageClassAttr>(sourceType.getMemorySpace());
707 if (!sourceStorageClassAttr)
709 diag <<
"source address space " << sourceType.getMemorySpace()
710 <<
" must be a SPIR-V storage class";
712 auto resultStorageClassAttr =
713 dyn_cast_or_null<spirv::StorageClassAttr>(resultType.getMemorySpace());
714 if (!resultStorageClassAttr)
716 diag <<
"result address space " << resultType.getMemorySpace()
717 <<
" must be a SPIR-V storage class";
720 spirv::StorageClass sourceSc = sourceStorageClassAttr.getValue();
721 spirv::StorageClass resultSc = resultStorageClassAttr.getValue();
723 Value result = adaptor.getSource();
724 Type resultPtrType = typeConverter.convertType(resultType);
727 "failed to convert memref type");
729 Type genericPtrType = resultPtrType;
737 if (sourceSc != spirv::StorageClass::Generic &&
738 resultSc != spirv::StorageClass::Generic) {
739 Type intermediateType =
741 sourceType.getLayout(),
742 rewriter.
getAttr<spirv::StorageClassAttr>(
743 spirv::StorageClass::Generic));
744 genericPtrType = typeConverter.convertType(intermediateType);
746 if (sourceSc != spirv::StorageClass::Generic) {
748 rewriter.
create<spirv::PtrCastToGenericOp>(loc, genericPtrType, result);
750 if (resultSc != spirv::StorageClass::Generic) {
752 rewriter.
create<spirv::GenericCastToPtrOp>(loc, resultPtrType, result);
759 StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
761 auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
762 if (memrefType.getElementType().isSignlessInteger())
765 *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
766 adaptor.getIndices(), storeOp.getLoc(), rewriter);
777 memref::ReinterpretCastOp op, OpAdaptor adaptor,
779 Value src = adaptor.getSource();
780 auto srcType = dyn_cast<spirv::PointerType>(src.
getType());
790 if (dstType != srcType)
792 diag <<
"invalid dst type " << op.getType();
796 getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter)
808 auto offsetValue = [&]() ->
Value {
809 if (
auto val = dyn_cast<Value>(offset))
812 int64_t attrVal = cast<IntegerAttr>(offset.get<
Attribute>()).getInt();
814 return rewriter.
create<spirv::ConstantOp>(loc, intType, attr);
818 op, src, offsetValue, std::nullopt);
829 patterns.
add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
830 DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern,
831 LoadOpPattern, MemorySpaceCastOpPattern, StoreOpPattern,
832 ReinterpretCastPattern, CastPattern>(typeConverter,
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder)
Casts the given srcInt into a boolean value.
static Value shiftValue(Location loc, Value value, Value offset, Value mask, OpBuilder &builder)
Returns the targetBits-bit value shifted by the given offset, and cast to the type destination type,...
static Value adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter, spirv::AccessChainOp op, int sourceBits, int targetBits, OpBuilder &builder)
Returns an adjusted spirv::AccessChainOp.
static std::optional< spirv::Scope > getAtomicOpScope(MemRefType type)
Returns the scope to use for atomic operations use for emulating store operations of unsupported inte...
static bool isAllocationSupported(Operation *allocOp, MemRefType type)
Returns true if the allocations of memref type generated from allocOp can be lowered to SPIR-V.
static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits, int targetBits, OpBuilder &builder)
Returns the offset of the value in targetBits representation.
#define ATOMIC_CASE(kind, spirvOp)
static Value castBoolToIntN(Location loc, Value srcBool, Type dstType, OpBuilder &builder)
Casts the given srcBool into an integer of dstType.
static std::string diag(const llvm::Value &value)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerType getIntegerType(unsigned width)
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Type conversion from builtin types to SPIR-V types for shader interface.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger(unsigned width) const
Return true if this is an integer type with the specified width.
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
This header declares functions that assist transformations in the MemRef dialect.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating MemRef ops to SPIR-V ops.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, Builder &b)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
This class represents an efficient way to signal success or failure.