28 #define DEBUG_TYPE "memref-to-spirv-pattern"
49 assert(targetBits % sourceBits == 0);
51 IntegerAttr idxAttr = builder.
getIntegerAttr(type, targetBits / sourceBits);
52 auto idx = builder.
createOrFold<spirv::ConstantOp>(loc, type, idxAttr);
53 IntegerAttr srcBitsAttr = builder.
getIntegerAttr(type, sourceBits);
55 builder.
createOrFold<spirv::ConstantOp>(loc, type, srcBitsAttr);
56 auto m = builder.
createOrFold<spirv::UModOp>(loc, srcIdx, idx);
57 return builder.
createOrFold<spirv::IMulOp>(loc, type, m, srcBitsValue);
70 spirv::AccessChainOp op,
int sourceBits,
72 assert(targetBits % sourceBits == 0);
73 const auto loc = op.getLoc();
74 Value lastDim = op->getOperand(op.getNumOperands() - 1);
76 IntegerAttr attr = builder.
getIntegerAttr(type, targetBits / sourceBits);
77 auto idx = builder.
createOrFold<spirv::ConstantOp>(loc, type, attr);
78 auto indices = llvm::to_vector<4>(op.getIndices());
80 assert(indices.size() == 2);
81 indices.back() = builder.
createOrFold<spirv::SDivOp>(loc, lastDim, idx);
83 return spirv::AccessChainOp::create(builder, loc, t, op.getBasePtr(),
94 Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
95 return builder.
createOrFold<spirv::SelectOp>(loc, dstType, srcBool, one,
103 IntegerType dstType = cast<IntegerType>(mask.
getType());
104 int targetBits =
static_cast<int>(dstType.getWidth());
106 assert(valueBits <= targetBits);
108 if (valueBits == 1) {
111 if (valueBits < targetBits) {
112 value = spirv::UConvertOp::create(
116 value = builder.
createOrFold<spirv::BitwiseAndOp>(loc, value, mask);
125 if (isa<memref::AllocOp, memref::DeallocOp>(allocOp)) {
126 auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
127 if (!sc || sc.getValue() != spirv::StorageClass::Workgroup)
129 }
else if (isa<memref::AllocaOp>(allocOp)) {
130 auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
131 if (!sc || sc.getValue() != spirv::StorageClass::Function)
139 if (!type.hasStaticShape())
142 Type elementType = type.getElementType();
143 if (
auto vecType = dyn_cast<VectorType>(elementType))
144 elementType = vecType.getElementType();
152 auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
153 switch (sc.getValue()) {
154 case spirv::StorageClass::StorageBuffer:
155 return spirv::Scope::Device;
156 case spirv::StorageClass::Workgroup:
157 return spirv::Scope::Workgroup;
170 return builder.
createOrFold<spirv::INotEqualOp>(loc, srcInt, one);
189 matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
202 matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
207 class AtomicRMWOpPattern final
213 matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
224 matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor,
234 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
244 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
254 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
264 matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
269 class MemorySpaceCastOpPattern final
275 matchAndRewrite(memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
285 matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
289 class ReinterpretCastPattern final
295 matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
304 matchAndRewrite(memref::CastOp op, OpAdaptor adaptor,
306 Value src = adaptor.getSource();
311 if (srcType != dstType)
313 diag <<
"types doesn't match: " << srcType <<
" and " << dstType;
322 class ExtractAlignedPointerAsIndexOpPattern final
328 matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
339 AllocaOpPattern::matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
341 MemRefType allocType = allocaOp.getType();
346 Type spirvType = getTypeConverter()->convertType(allocType);
351 spirv::StorageClass::Function,
361 AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
363 MemRefType allocType = operation.getType();
368 Type spirvType = getTypeConverter()->convertType(allocType);
378 spirv::GlobalVariableOp varOp;
383 auto varOps = entryBlock.
getOps<spirv::GlobalVariableOp>();
384 std::string varName =
385 std::string(
"__workgroup_mem__") +
386 std::to_string(std::distance(varOps.begin(), varOps.end()));
387 varOp = spirv::GlobalVariableOp::create(rewriter, loc, spirvType, varName,
401 AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
404 if (isa<FloatType>(atomicOp.getType()))
406 "unimplemented floating-point case");
408 auto memrefType = cast<MemRefType>(atomicOp.getMemref().getType());
412 "unsupported memref memory space");
414 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
415 Type resultType = typeConverter.convertType(atomicOp.getType());
418 "failed to convert result type");
420 auto loc = atomicOp.getLoc();
423 adaptor.getIndices(), loc, rewriter);
428 #define ATOMIC_CASE(kind, spirvOp) \
429 case arith::AtomicRMWKind::kind: \
430 rewriter.replaceOpWithNewOp<spirv::spirvOp>( \
431 atomicOp, resultType, ptr, *scope, \
432 spirv::MemorySemantics::AcquireRelease, adaptor.getValue()); \
435 switch (atomicOp.getKind()) {
457 DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
460 MemRefType deallocType = cast<MemRefType>(operation.getMemref().getType());
478 static FailureOr<MemoryRequirements>
480 uint64_t preferredAlignment) {
489 memoryAccess = spirv::MemoryAccess::Nontemporal;
492 auto ptrType = cast<spirv::PointerType>(accessedPtr.
getType());
493 bool mayOmitAlignment =
494 !preferredAlignment &&
495 ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer;
496 if (mayOmitAlignment) {
506 auto pointeeType = dyn_cast<spirv::ScalarType>(ptrType.getPointeeType());
511 std::optional<int64_t> sizeInBytes = pointeeType.getSizeInBytes();
512 if (!sizeInBytes.has_value())
515 memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
517 auto alignmentValue = preferredAlignment ? preferredAlignment : *sizeInBytes;
525 template <
class LoadOrStoreOp>
526 static FailureOr<MemoryRequirements>
529 llvm::is_one_of<LoadOrStoreOp, memref::LoadOp, memref::StoreOp>::value,
530 "Must be called on either memref::LoadOp or memref::StoreOp");
533 loadOrStoreOp.getNontemporal(),
534 loadOrStoreOp.getAlignment().value_or(0));
538 IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
540 auto loc = loadOp.getLoc();
541 auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
542 if (!memrefType.getElementType().isSignlessInteger())
545 auto memorySpaceAttr =
546 dyn_cast_if_present<spirv::StorageClassAttr>(memrefType.getMemorySpace());
547 if (!memorySpaceAttr)
549 loadOp,
"missing memory space SPIR-V storage class attribute");
551 if (memorySpaceAttr.getValue() == spirv::StorageClass::Image)
554 "failed to lower memref in image storage class to storage buffer");
556 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
559 adaptor.getIndices(), loc, rewriter);
564 int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
565 bool isBool = srcBits == 1;
567 srcBits = typeConverter.getOptions().boolNumBits;
573 Type pointeeType = pointerType.getPointeeType();
575 if (typeConverter.allows(spirv::Capability::Kernel)) {
576 if (
auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
577 dstType = arrayType.getElementType();
579 dstType = pointeeType;
582 Type structElemType =
583 cast<spirv::StructType>(pointeeType).getElementType(0);
584 if (
auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
585 dstType = arrayType.getElementType();
587 dstType = cast<spirv::RuntimeArrayType>(structElemType).getElementType();
590 assert(dstBits % srcBits == 0);
594 if (srcBits == dstBits) {
596 if (
failed(memoryRequirements))
598 loadOp,
"failed to determine memory requirements");
600 auto [memoryAccess, alignment] = *memoryRequirements;
601 Value loadVal = spirv::LoadOp::create(rewriter, loc, accessChain,
602 memoryAccess, alignment);
611 if (typeConverter.allows(spirv::Capability::Kernel))
614 auto accessChainOp = accessChain.
getDefiningOp<spirv::AccessChainOp>();
621 assert(accessChainOp.getIndices().size() == 2);
623 srcBits, dstBits, rewriter);
625 if (
failed(memoryRequirements))
627 loadOp,
"failed to determine memory requirements");
629 auto [memoryAccess, alignment] = *memoryRequirements;
630 Value spvLoadOp = spirv::LoadOp::create(rewriter, loc, dstType, adjustedPtr,
631 memoryAccess, alignment);
635 Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
638 loc, spvLoadOp.
getType(), spvLoadOp, offset);
642 loc, dstType, rewriter.
getIntegerAttr(dstType, (1 << srcBits) - 1));
644 rewriter.
createOrFold<spirv::BitwiseAndOp>(loc, dstType, result, mask);
649 IntegerAttr shiftValueAttr =
652 rewriter.
createOrFold<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
653 result = rewriter.
createOrFold<spirv::ShiftLeftLogicalOp>(loc, dstType,
655 result = rewriter.
createOrFold<spirv::ShiftRightArithmeticOp>(
660 assert(accessChainOp.use_empty());
661 rewriter.
eraseOp(accessChainOp);
667 LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
669 auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
670 if (memrefType.getElementType().isSignlessInteger())
673 auto memorySpaceAttr =
674 dyn_cast_if_present<spirv::StorageClassAttr>(memrefType.getMemorySpace());
675 if (!memorySpaceAttr)
677 loadOp,
"missing memory space SPIR-V storage class attribute");
679 if (memorySpaceAttr.getValue() == spirv::StorageClass::Image)
682 "failed to lower memref in image storage class to storage buffer");
685 *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
686 adaptor.getIndices(), loadOp.getLoc(), rewriter);
692 if (
failed(memoryRequirements))
694 loadOp,
"failed to determine memory requirements");
696 auto [memoryAccess, alignment] = *memoryRequirements;
703 ImageLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
705 auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
707 auto memorySpaceAttr =
708 dyn_cast_if_present<spirv::StorageClassAttr>(memrefType.getMemorySpace());
709 if (!memorySpaceAttr)
711 loadOp,
"missing memory space SPIR-V storage class attribute");
713 if (memorySpaceAttr.getValue() != spirv::StorageClass::Image)
715 loadOp,
"failed to lower memref in non-image storage class to image");
717 Value loadPtr = adaptor.getMemref();
719 if (
failed(memoryRequirements))
721 loadOp,
"failed to determine memory requirements");
723 const auto [memoryAccess, alignment] = *memoryRequirements;
725 if (!loadOp.getMemRefType().hasRank())
727 loadOp,
"cannot lower unranked memrefs to SPIR-V images");
732 if (!isa<spirv::ScalarType>(loadOp.getMemRefType().getElementType()))
735 "cannot lower memrefs who's element type is not a SPIR-V scalar type"
742 auto convertedPointeeType = cast<spirv::PointerType>(
743 getTypeConverter()->convertType(loadOp.getMemRefType()));
744 if (!isa<spirv::SampledImageType>(convertedPointeeType.getPointeeType()))
746 "cannot lower memrefs which do not "
747 "convert to SPIR-V sampled images");
752 spirv::LoadOp::create(rewriter, loc, loadPtr, memoryAccess, alignment);
754 auto imageOp = spirv::ImageOp::create(rewriter, loc, imageLoadOp);
758 if (memrefType.getRank() != 1) {
759 auto coordVectorType =
VectorType::get({loadOp.getMemRefType().getRank()},
760 adaptor.getIndices().
getType()[0]);
761 coords = spirv::CompositeConstructOp::create(rewriter, loc, coordVectorType,
762 adaptor.getIndices());
764 coords = adaptor.getIndices()[0];
769 auto fetchOp = spirv::ImageFetchOp::create(
770 rewriter, loc, resultVectorType, imageOp, coords,
771 mlir::spirv::ImageOperandsAttr{},
ValueRange{});
776 auto compositeExtractOp =
777 spirv::CompositeExtractOp::create(rewriter, loc, fetchOp, 0);
779 rewriter.
replaceOp(loadOp, compositeExtractOp);
784 IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
786 auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
787 if (!memrefType.getElementType().isSignlessInteger())
789 "element type is not a signless int");
791 auto loc = storeOp.getLoc();
792 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
795 adaptor.getIndices(), loc, rewriter);
799 storeOp,
"failed to convert element pointer type");
801 int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
803 bool isBool = srcBits == 1;
805 srcBits = typeConverter.getOptions().boolNumBits;
810 "failed to convert memref type");
812 Type pointeeType = pointerType.getPointeeType();
814 if (typeConverter.allows(spirv::Capability::Kernel)) {
815 if (
auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
816 dstType = dyn_cast<IntegerType>(arrayType.getElementType());
818 dstType = dyn_cast<IntegerType>(pointeeType);
821 Type structElemType =
822 cast<spirv::StructType>(pointeeType).getElementType(0);
823 if (
auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
824 dstType = dyn_cast<IntegerType>(arrayType.getElementType());
826 dstType = dyn_cast<IntegerType>(
832 storeOp,
"failed to determine destination element type");
834 int dstBits =
static_cast<int>(dstType.getWidth());
835 assert(dstBits % srcBits == 0);
837 if (srcBits == dstBits) {
839 if (
failed(memoryRequirements))
841 storeOp,
"failed to determine memory requirements");
843 auto [memoryAccess, alignment] = *memoryRequirements;
844 Value storeVal = adaptor.getValue();
848 memoryAccess, alignment);
854 if (typeConverter.allows(spirv::Capability::Kernel))
857 auto accessChainOp = accessChain.
getDefiningOp<spirv::AccessChainOp>();
872 assert(accessChainOp.getIndices().size() == 2);
873 Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
879 loc, dstType, rewriter.
getIntegerAttr(dstType, (1 << srcBits) - 1));
881 loc, dstType, mask, offset);
883 rewriter.
createOrFold<spirv::NotOp>(loc, dstType, clearBitsMask);
885 Value storeVal =
shiftValue(loc, adaptor.getValue(), offset, mask, rewriter);
887 srcBits, dstBits, rewriter);
892 Value result = spirv::AtomicAndOp::create(
893 rewriter, loc, dstType, adjustedPtr, *scope,
894 spirv::MemorySemantics::AcquireRelease, clearBitsMask);
895 result = spirv::AtomicOrOp::create(
896 rewriter, loc, dstType, adjustedPtr, *scope,
897 spirv::MemorySemantics::AcquireRelease, storeVal);
905 assert(accessChainOp.use_empty());
906 rewriter.
eraseOp(accessChainOp);
915 LogicalResult MemorySpaceCastOpPattern::matchAndRewrite(
916 memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
919 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
920 if (!typeConverter.allows(spirv::Capability::Kernel))
922 loc,
"address space casts require kernel capability");
924 auto sourceType = dyn_cast<MemRefType>(addrCastOp.getSource().getType());
927 loc,
"SPIR-V lowering requires ranked memref types");
928 auto resultType = cast<MemRefType>(addrCastOp.getResult().getType());
930 auto sourceStorageClassAttr =
931 dyn_cast_or_null<spirv::StorageClassAttr>(sourceType.getMemorySpace());
932 if (!sourceStorageClassAttr)
934 diag <<
"source address space " << sourceType.getMemorySpace()
935 <<
" must be a SPIR-V storage class";
937 auto resultStorageClassAttr =
938 dyn_cast_or_null<spirv::StorageClassAttr>(resultType.getMemorySpace());
939 if (!resultStorageClassAttr)
941 diag <<
"result address space " << resultType.getMemorySpace()
942 <<
" must be a SPIR-V storage class";
945 spirv::StorageClass sourceSc = sourceStorageClassAttr.getValue();
946 spirv::StorageClass resultSc = resultStorageClassAttr.getValue();
948 Value result = adaptor.getSource();
949 Type resultPtrType = typeConverter.convertType(resultType);
952 "failed to convert memref type");
954 Type genericPtrType = resultPtrType;
962 if (sourceSc != spirv::StorageClass::Generic &&
963 resultSc != spirv::StorageClass::Generic) {
964 Type intermediateType =
966 sourceType.getLayout(),
967 rewriter.
getAttr<spirv::StorageClassAttr>(
968 spirv::StorageClass::Generic));
969 genericPtrType = typeConverter.convertType(intermediateType);
971 if (sourceSc != spirv::StorageClass::Generic) {
972 result = spirv::PtrCastToGenericOp::create(rewriter, loc, genericPtrType,
975 if (resultSc != spirv::StorageClass::Generic) {
977 spirv::GenericCastToPtrOp::create(rewriter, loc, resultPtrType, result);
984 StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
986 auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
987 if (memrefType.getElementType().isSignlessInteger())
990 *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
991 adaptor.getIndices(), storeOp.getLoc(), rewriter);
997 if (
failed(memoryRequirements))
999 storeOp,
"failed to determine memory requirements");
1001 auto [memoryAccess, alignment] = *memoryRequirements;
1003 storeOp, storePtr, adaptor.getValue(), memoryAccess, alignment);
1007 LogicalResult ReinterpretCastPattern::matchAndRewrite(
1008 memref::ReinterpretCastOp op, OpAdaptor adaptor,
1010 Value src = adaptor.getSource();
1011 auto srcType = dyn_cast<spirv::PointerType>(src.
getType());
1021 if (dstType != srcType)
1023 diag <<
"invalid dst type " << op.getType();
1027 getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter)
1039 auto offsetValue = [&]() ->
Value {
1040 if (
auto val = dyn_cast<Value>(offset))
1043 int64_t attrVal = cast<IntegerAttr>(cast<Attribute>(offset)).getInt();
1045 return rewriter.
createOrFold<spirv::ConstantOp>(loc, intType, attr);
1057 LogicalResult ExtractAlignedPointerAsIndexOpPattern::matchAndRewrite(
1058 memref::ExtractAlignedPointerAsIndexOp extractOp, OpAdaptor adaptor,
1060 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
1061 Type indexType = typeConverter.getIndexType();
1063 adaptor.getSource());
1074 patterns.add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
1075 DeallocOpPattern, IntLoadOpPattern, ImageLoadOpPattern,
1076 IntStoreOpPattern, LoadOpPattern, MemorySpaceCastOpPattern,
1077 StoreOpPattern, ReinterpretCastPattern, CastPattern,
1078 ExtractAlignedPointerAsIndexOpPattern>(typeConverter,
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static Type getElementType(Type type)
Determine the element type of type.
static FailureOr< MemoryRequirements > calculateMemoryRequirements(Value accessedPtr, bool isNontemporal, uint64_t preferredAlignment)
Given an accessed SPIR-V pointer, calculates its alignment requirements, if any.
static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder)
Casts the given srcInt into a boolean value.
static Value shiftValue(Location loc, Value value, Value offset, Value mask, OpBuilder &builder)
Returns the targetBits-bit value shifted by the given offset, and cast to the type destination type,...
static Value adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter, spirv::AccessChainOp op, int sourceBits, int targetBits, OpBuilder &builder)
Returns an adjusted spirv::AccessChainOp.
static std::optional< spirv::Scope > getAtomicOpScope(MemRefType type)
Returns the scope to use for atomic operations use for emulating store operations of unsupported inte...
static bool isAllocationSupported(Operation *allocOp, MemRefType type)
Returns true if the allocations of memref type generated from allocOp can be lowered to SPIR-V.
static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits, int targetBits, OpBuilder &builder)
Returns the offset of the value in targetBits representation.
#define ATOMIC_CASE(kind, spirvOp)
static Value castBoolToIntN(Location loc, Value srcBool, Type dstType, OpBuilder &builder)
Casts the given srcBool into an integer of dstType.
static std::string diag(const llvm::Value &value)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerType getIntegerType(unsigned width)
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Type conversion from builtin types to SPIR-V types for shader interface.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
void populateMemRefToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating MemRef ops to SPIR-V ops.
spirv::MemoryAccessAttr memoryAccess