29 #include "llvm/ADT/ArrayRef.h"
30 #include "llvm/ADT/STLExtras.h"
31 #include "llvm/ADT/SmallVector.h"
32 #include "llvm/ADT/SmallVectorExtras.h"
33 #include "llvm/Support/FormatVariadic.h"
44 return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
52 if (
auto vectorType = dyn_cast<VectorType>(type))
53 return vectorType.getNumElements() * vectorType.getElementTypeBitWidth();
63 matchAndRewrite(vector::ShapeCastOp shapeCastOp, OpAdaptor adaptor,
65 Type dstType = getTypeConverter()->convertType(shapeCastOp.getType());
71 if (dstType == adaptor.getSource().getType() ||
72 shapeCastOp.getResultVectorType().getNumElements() == 1) {
73 rewriter.
replaceOp(shapeCastOp, adaptor.getSource());
84 struct VectorSplatToBroadcast final
88 matchAndRewrite(vector::SplatOp splat, OpAdaptor adaptor,
96 struct VectorBitcastConvert final
101 matchAndRewrite(vector::BitCastOp bitcastOp, OpAdaptor adaptor,
103 Type dstType = getTypeConverter()->convertType(bitcastOp.getType());
107 if (dstType == adaptor.getSource().getType()) {
108 rewriter.
replaceOp(bitcastOp, adaptor.getSource());
115 Type srcType = adaptor.getSource().getType();
119 llvm::formatv(
"different source ({0}) and target ({1}) bitwidth",
124 adaptor.getSource());
129 struct VectorBroadcastConvert final
134 matchAndRewrite(vector::BroadcastOp castOp, OpAdaptor adaptor,
137 getTypeConverter()->convertType(castOp.getResultVectorType());
141 if (isa<spirv::ScalarType>(resultType)) {
142 rewriter.
replaceOp(castOp, adaptor.getSource());
147 adaptor.getSource());
162 int64_t kPoisonIndex,
unsigned vectorSize) {
163 if (llvm::isPowerOf2_32(vectorSize)) {
164 Value inBoundsMask = spirv::ConstantOp::create(
165 rewriter, loc, dynamicIndex.
getType(),
167 return spirv::BitwiseAndOp::create(rewriter, loc, dynamicIndex,
170 Value poisonIndex = spirv::ConstantOp::create(
171 rewriter, loc, dynamicIndex.
getType(),
174 spirv::IEqualOp::create(rewriter, loc, dynamicIndex, poisonIndex);
175 return spirv::SelectOp::create(
176 rewriter, loc, cmpResult,
181 struct VectorExtractOpConvert final
186 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
188 Type dstType = getTypeConverter()->convertType(extractOp.getType());
192 if (isa<spirv::ScalarType>(adaptor.getVector().getType())) {
193 rewriter.
replaceOp(extractOp, adaptor.getVector());
197 if (std::optional<int64_t>
id =
199 if (
id == vector::ExtractOp::kPoisonIndex)
202 "Static use of poison index handled elsewhere (folded to poison)");
204 extractOp, dstType, adaptor.getVector(),
207 Value sanitizedIndex = sanitizeDynamicIndex(
208 rewriter, extractOp.getLoc(), adaptor.getDynamicPosition()[0],
209 vector::ExtractOp::kPoisonIndex,
210 extractOp.getSourceVectorType().getNumElements());
212 extractOp, dstType, adaptor.getVector(), sanitizedIndex);
218 struct VectorExtractStridedSliceOpConvert final
223 matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
225 Type dstType = getTypeConverter()->convertType(extractOp.getType());
235 Value srcVector = adaptor.getOperands().front();
238 if (isa<spirv::ScalarType>(dstType)) {
245 std::iota(indices.begin(), indices.end(), offset);
248 extractOp, dstType, srcVector, srcVector,
255 template <
class SPIRVFMAOp>
260 matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
262 Type dstType = getTypeConverter()->convertType(fmaOp.getType());
266 adaptor.getRhs(), adaptor.getAcc());
271 struct VectorFromElementsOpConvert final
276 matchAndRewrite(vector::FromElementsOp op, OpAdaptor adaptor,
278 Type resultType = getTypeConverter()->convertType(op.getType());
282 if (isa<spirv::ScalarType>(resultType)) {
290 assert(cast<VectorType>(resultType).getRank() == 1);
297 struct VectorInsertOpConvert final
302 matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
304 if (isa<VectorType>(insertOp.getValueToStoreType()))
306 if (!getTypeConverter()->convertType(insertOp.getDestVectorType()))
308 "unsupported dest vector type");
311 if (insertOp.getValueToStoreType().isIntOrFloat() &&
312 insertOp.getDestVectorType().getNumElements() == 1) {
313 rewriter.
replaceOp(insertOp, adaptor.getValueToStore());
317 if (std::optional<int64_t>
id =
319 if (
id == vector::InsertOp::kPoisonIndex)
322 "Static use of poison index handled elsewhere (folded to poison)");
324 insertOp, adaptor.getValueToStore(), adaptor.getDest(),
id.value());
326 Value sanitizedIndex = sanitizeDynamicIndex(
327 rewriter, insertOp.getLoc(), adaptor.getDynamicPosition()[0],
328 vector::InsertOp::kPoisonIndex,
329 insertOp.getDestVectorType().getNumElements());
331 insertOp, insertOp.getDest(), adaptor.getValueToStore(),
338 struct VectorInsertStridedSliceOpConvert final
343 matchAndRewrite(vector::InsertStridedSliceOp insertOp, OpAdaptor adaptor,
345 Value srcVector = adaptor.getOperands().front();
346 Value dstVector = adaptor.getOperands().back();
353 if (isa<spirv::ScalarType>(srcVector.
getType())) {
354 assert(!isa<spirv::ScalarType>(dstVector.
getType()));
356 insertOp, dstVector.
getType(), srcVector, dstVector,
361 uint64_t totalSize = cast<VectorType>(dstVector.getType()).getNumElements();
362 uint64_t insertSize =
363 cast<VectorType>(srcVector.
getType()).getNumElements();
366 std::iota(indices.begin(), indices.end(), 0);
367 std::iota(indices.begin() + offset, indices.begin() + offset + insertSize,
371 insertOp, dstVector.getType(), dstVector, srcVector,
379 vector::ReductionOp reduceOp, vector::ReductionOp::Adaptor adaptor,
381 int numElements =
static_cast<int>(srcVectorType.getDimSize(0));
383 values.reserve(numElements + (adaptor.getAcc() ? 1 : 0));
386 for (
int i = 0; i < numElements; ++i) {
387 values.push_back(spirv::CompositeExtractOp::create(
388 rewriter, loc, srcVectorType.getElementType(), adaptor.getVector(),
391 if (
Value acc = adaptor.getAcc())
392 values.push_back(acc);
397 struct ReductionRewriteInfo {
402 FailureOr<ReductionRewriteInfo>
static getReductionInfo(
403 vector::ReductionOp op, vector::ReductionOp::Adaptor adaptor,
409 auto srcVectorType = dyn_cast<VectorType>(adaptor.getVector().getType());
410 if (!srcVectorType || srcVectorType.getRank() != 1)
414 extractAllElements(op, adaptor, srcVectorType, rewriter);
416 return ReductionRewriteInfo{resultType, std::move(extractedElements)};
419 template <
typename SPIRVUMaxOp,
typename SPIRVUMinOp,
typename SPIRVSMaxOp,
420 typename SPIRVSMinOp>
425 matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
428 getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
429 if (
failed(reductionInfo))
432 auto [resultType, extractedElements] = *reductionInfo;
434 Value result = extractedElements.front();
435 for (
Value next : llvm::drop_begin(extractedElements)) {
436 switch (reduceOp.getKind()) {
438 #define INT_AND_FLOAT_CASE(kind, iop, fop) \
439 case vector::CombiningKind::kind: \
440 if (llvm::isa<IntegerType>(resultType)) { \
441 result = spirv::iop::create(rewriter, loc, resultType, result, next); \
443 assert(llvm::isa<FloatType>(resultType)); \
444 result = spirv::fop::create(rewriter, loc, resultType, result, next); \
448 #define INT_OR_FLOAT_CASE(kind, fop) \
449 case vector::CombiningKind::kind: \
450 result = fop::create(rewriter, loc, resultType, result, next); \
460 case vector::CombiningKind::AND:
461 case vector::CombiningKind::OR:
462 case vector::CombiningKind::XOR:
467 #undef INT_AND_FLOAT_CASE
468 #undef INT_OR_FLOAT_CASE
476 template <
typename SPIRVFMaxOp,
typename SPIRVFMinOp>
477 struct VectorReductionFloatMinMax final
482 matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
485 getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
486 if (
failed(reductionInfo))
489 auto [resultType, extractedElements] = *reductionInfo;
491 Value result = extractedElements.front();
492 for (
Value next : llvm::drop_begin(extractedElements)) {
493 switch (reduceOp.getKind()) {
495 #define INT_OR_FLOAT_CASE(kind, fop) \
496 case vector::CombiningKind::kind: \
497 result = fop::create(rewriter, loc, resultType, result, next); \
508 #undef INT_OR_FLOAT_CASE
516 class VectorScalarBroadcastPattern final
522 matchAndRewrite(vector::BroadcastOp op, OpAdaptor adaptor,
524 if (isa<VectorType>(op.getSourceType())) {
526 op,
"only conversion of 'broadcast from scalar' is supported");
528 Type dstType = getTypeConverter()->convertType(op.getType());
531 if (isa<spirv::ScalarType>(dstType)) {
532 rewriter.
replaceOp(op, adaptor.getSource());
534 auto dstVecType = cast<VectorType>(dstType);
536 adaptor.getSource());
544 struct VectorShuffleOpConvert final
549 matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
551 VectorType oldResultType = shuffleOp.getResultVectorType();
552 Type newResultType = getTypeConverter()->convertType(oldResultType);
555 "unsupported result vector type");
557 auto mask = llvm::to_vector_of<int32_t>(shuffleOp.getMask());
559 VectorType oldV1Type = shuffleOp.getV1VectorType();
560 VectorType oldV2Type = shuffleOp.getV2VectorType();
564 if (oldV1Type.getNumElements() > 1 && oldV2Type.getNumElements() > 1 &&
565 oldResultType.getNumElements() > 1) {
567 shuffleOp, newResultType, adaptor.getV1(), adaptor.getV2(),
575 auto getElementAtIdx = [&rewriter, loc = shuffleOp.getLoc()](
577 if (
auto vecTy = dyn_cast<VectorType>(scalarOrVec.getType()))
578 return spirv::CompositeExtractOp::create(rewriter, loc, scalarOrVec,
581 assert(idx == 0 &&
"Invalid scalar element index");
585 int32_t numV1Elems = oldV1Type.getNumElements();
587 for (
auto [shuffleIdx, newOperand] : llvm::zip_equal(mask, newOperands)) {
588 Value vec = adaptor.getV1();
589 int32_t elementIdx = shuffleIdx;
590 if (elementIdx >= numV1Elems) {
591 vec = adaptor.getV2();
592 elementIdx -= numV1Elems;
595 newOperand = getElementAtIdx(vec, elementIdx);
599 if (newOperands.size() == 1) {
600 rewriter.
replaceOp(shuffleOp, newOperands.front());
605 shuffleOp, newResultType, newOperands);
610 struct VectorInterleaveOpConvert final
615 matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
618 VectorType oldResultType = interleaveOp.getResultVectorType();
619 Type newResultType = getTypeConverter()->convertType(oldResultType);
622 "unsupported result vector type");
625 VectorType sourceType = interleaveOp.getSourceVectorType();
626 int n = sourceType.getNumElements();
632 Value newOperands[] = {adaptor.getLhs(), adaptor.getRhs()};
634 interleaveOp, newResultType, newOperands);
638 auto seq = llvm::seq<int64_t>(2 * n);
639 auto indices = llvm::map_to_vector(
640 seq, [n](
int i) {
return (i % 2 ? n : 0) + i / 2; });
644 interleaveOp, newResultType, adaptor.getLhs(), adaptor.getRhs(),
651 struct VectorDeinterleaveOpConvert final
656 matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
660 VectorType oldResultType = deinterleaveOp.getResultVectorType();
661 Type newResultType = getTypeConverter()->convertType(oldResultType);
664 "unsupported result vector type");
666 Location loc = deinterleaveOp->getLoc();
669 Value sourceVector = adaptor.getSource();
670 VectorType sourceType = deinterleaveOp.getSourceVectorType();
671 int n = sourceType.getNumElements();
677 auto elem0 = spirv::CompositeExtractOp::create(
678 rewriter, loc, newResultType, sourceVector,
681 auto elem1 = spirv::CompositeExtractOp::create(
682 rewriter, loc, newResultType, sourceVector,
685 rewriter.
replaceOp(deinterleaveOp, {elem0, elem1});
690 auto seqEven = llvm::seq<int64_t>(n / 2);
692 llvm::map_to_vector(seqEven, [](
int i) {
return i * 2; });
695 auto seqOdd = llvm::seq<int64_t>(n / 2);
697 llvm::map_to_vector(seqOdd, [](
int i) {
return i * 2 + 1; });
700 auto shuffleEven = spirv::VectorShuffleOp::create(
701 rewriter, loc, newResultType, sourceVector, sourceVector,
704 auto shuffleOdd = spirv::VectorShuffleOp::create(
705 rewriter, loc, newResultType, sourceVector, sourceVector,
708 rewriter.
replaceOp(deinterleaveOp, {shuffleEven, shuffleOdd});
713 struct VectorLoadOpConverter final
718 matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
720 auto memrefType = loadOp.getMemRefType();
722 dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
725 loadOp,
"expected spirv.storage_class memory space");
727 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
728 auto loc = loadOp.getLoc();
731 adaptor.getIndices(), loc, rewriter);
734 loadOp,
"failed to get memref element pointer");
736 spirv::StorageClass storageClass = attr.getValue();
737 auto vectorType = loadOp.getVectorType();
740 auto spirvVectorType = typeConverter.
convertType(vectorType);
741 if (!spirvVectorType)
746 std::optional<uint64_t> alignment = loadOp.getAlignment();
749 "invalid alignment requirement");
753 spirv::MemoryAccessAttr memoryAccessAttr;
754 IntegerAttr alignmentAttr;
755 if (alignment.has_value()) {
756 memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
765 Value castedAccessChain =
766 (vectorType.getNumElements() == 1)
768 : spirv::BitcastOp::create(rewriter, loc, vectorPtrType,
773 memoryAccessAttr, alignmentAttr);
779 struct VectorStoreOpConverter final
784 matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
786 auto memrefType = storeOp.getMemRefType();
788 dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
791 storeOp,
"expected spirv.storage_class memory space");
793 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
794 auto loc = storeOp.getLoc();
797 adaptor.getIndices(), loc, rewriter);
800 storeOp,
"failed to get memref element pointer");
802 std::optional<uint64_t> alignment = storeOp.getAlignment();
805 "invalid alignment requirement");
808 spirv::StorageClass storageClass = attr.getValue();
809 auto vectorType = storeOp.getVectorType();
815 Value castedAccessChain =
816 (vectorType.getNumElements() == 1)
818 : spirv::BitcastOp::create(rewriter, loc, vectorPtrType,
822 spirv::MemoryAccessAttr memoryAccessAttr;
823 IntegerAttr alignmentAttr;
824 if (alignment.has_value()) {
825 memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
832 storeOp, castedAccessChain, adaptor.getValueToStore(), memoryAccessAttr,
839 struct VectorReductionToIntDotProd final
843 LogicalResult matchAndRewrite(vector::ReductionOp op,
845 if (op.getKind() != vector::CombiningKind::ADD)
848 auto resultType = dyn_cast<IntegerType>(op.getType());
853 if (!llvm::is_contained({32, 64}, resultBitwidth))
856 VectorType inVecTy = op.getSourceVectorType();
857 if (!llvm::is_contained({4, 3}, inVecTy.getNumElements()) ||
858 inVecTy.getShape().size() != 1 || inVecTy.isScalable())
861 auto mul = op.getVector().getDefiningOp<arith::MulIOp>();
864 op,
"reduction operand is not 'arith.muli'");
866 if (succeeded(handleCase<arith::ExtSIOp, arith::ExtSIOp, spirv::SDotOp,
867 spirv::SDotAccSatOp,
false>(op, mul, rewriter)))
870 if (succeeded(handleCase<arith::ExtUIOp, arith::ExtUIOp, spirv::UDotOp,
871 spirv::UDotAccSatOp,
false>(op, mul, rewriter)))
874 if (succeeded(handleCase<arith::ExtSIOp, arith::ExtUIOp, spirv::SUDotOp,
875 spirv::SUDotAccSatOp,
false>(op, mul, rewriter)))
878 if (succeeded(handleCase<arith::ExtUIOp, arith::ExtSIOp, spirv::SUDotOp,
879 spirv::SUDotAccSatOp,
true>(op, mul, rewriter)))
886 template <
typename LhsExtensionOp,
typename RhsExtensionOp,
typename DotOp,
887 typename DotAccOp,
bool SwapOperands>
888 static LogicalResult handleCase(vector::ReductionOp op, arith::MulIOp mul,
890 auto lhs = mul.getLhs().getDefiningOp<LhsExtensionOp>();
893 Value lhsIn = lhs.getIn();
894 auto lhsInType = cast<VectorType>(lhsIn.
getType());
895 if (!lhsInType.getElementType().isInteger(8))
898 auto rhs = mul.getRhs().getDefiningOp<RhsExtensionOp>();
901 Value rhsIn = rhs.getIn();
902 auto rhsInType = cast<VectorType>(rhsIn.
getType());
903 if (!rhsInType.getElementType().isInteger(8))
906 if (op.getSourceVectorType().getNumElements() == 3) {
907 IntegerType i8Type = rewriter.
getI8Type();
911 lhsIn = spirv::CompositeConstructOp::create(rewriter, loc, v4i8Type,
913 rhsIn = spirv::CompositeConstructOp::create(rewriter, loc, v4i8Type,
920 std::swap(lhsIn, rhsIn);
922 if (
Value acc = op.getAcc()) {
934 struct VectorReductionToFPDotProd final
939 matchAndRewrite(vector::ReductionOp op, OpAdaptor adaptor,
941 if (op.getKind() != vector::CombiningKind::ADD)
944 auto resultType = getTypeConverter()->convertType<FloatType>(op.getType());
948 Value vec = adaptor.getVector();
949 Value acc = adaptor.getAcc();
951 auto vectorType = dyn_cast<VectorType>(vec.
getType());
953 assert(isa<FloatType>(vec.
getType()) &&
954 "Expected the vector to be scalarized");
975 rewriter.
getFloatAttr(vectorType.getElementType(), 1.0);
977 rhs = spirv::ConstantOp::create(rewriter, loc, vectorType, oneAttr);
982 Value res = spirv::DotOp::create(rewriter, loc, resultType, lhs, rhs);
984 res = spirv::FAddOp::create(rewriter, loc, acc, res);
995 matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
997 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
1003 int64_t numElements = stepOp.getType().getNumElements();
1009 if (numElements == 1) {
1016 source.reserve(numElements);
1017 for (int64_t i = 0; i < numElements; ++i) {
1020 spirv::ConstantOp::create(rewriter, loc, intType, intAttr);
1021 source.push_back(constOp);
1029 struct VectorToElementOpConvert final
1034 matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
1038 Location loc = toElementsOp.getLoc();
1043 if (isa<spirv::ScalarType>(adaptor.getSource().getType())) {
1044 results[0] = adaptor.getSource();
1045 rewriter.
replaceOp(toElementsOp, results);
1049 Type srcElementType = toElementsOp.getElements().getType().front();
1050 Type elementType = getTypeConverter()->convertType(srcElementType);
1054 llvm::formatv(
"failed to convert element type '{0}' to SPIR-V",
1057 for (
auto [idx, element] :
llvm::enumerate(toElementsOp.getElements())) {
1060 if (element.use_empty())
1063 Value result = spirv::CompositeExtractOp::create(
1064 rewriter, loc, elementType, adaptor.getSource(),
1066 results[idx] = result;
1069 rewriter.
replaceOp(toElementsOp, results);
1075 #define CL_INT_MAX_MIN_OPS \
1076 spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
1078 #define GL_INT_MAX_MIN_OPS \
1079 spirv::GLUMaxOp, spirv::GLUMinOp, spirv::GLSMaxOp, spirv::GLSMinOp
1081 #define CL_FLOAT_MAX_MIN_OPS spirv::CLFMaxOp, spirv::CLFMinOp
1082 #define GL_FLOAT_MAX_MIN_OPS spirv::GLFMaxOp, spirv::GLFMinOp
1087 VectorBitcastConvert, VectorBroadcastConvert, VectorExtractOpConvert,
1088 VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
1089 VectorFmaOpConvert<spirv::CLFmaOp>, VectorFromElementsOpConvert,
1090 VectorToElementOpConvert, VectorInsertOpConvert,
1091 VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
1092 VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
1093 VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
1094 VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
1095 VectorSplatToBroadcast, VectorInsertStridedSliceOpConvert,
1096 VectorShuffleOpConvert, VectorInterleaveOpConvert,
1097 VectorDeinterleaveOpConvert, VectorScalarBroadcastPattern,
1098 VectorLoadOpConverter, VectorStoreOpConverter, VectorStepOpConvert>(
1103 patterns.add<VectorReductionToFPDotProd>(typeConverter,
patterns.getContext(),
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static uint64_t getFirstIntValue(ArrayAttr attr)
Returns the integer value from the first valid input element, assuming Value inputs are defined by a ...
static int getNumBits(Type type)
Returns the number of bits for the given scalar/vector type.
#define INT_AND_FLOAT_CASE(kind, iop, fop)
#define INT_OR_FLOAT_CASE(kind, fop)
Attributes are known-constant values of operations.
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class implements the operand iterators for the Operation class.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Type conversion from builtin types to SPIR-V types for shader interface.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static PointerType get(Type pointeeType, StorageClass storageClass)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void populateVectorReductionToSPIRVDotProductPatterns(RewritePatternSet &patterns)
Appends patterns to convert vector reduction of the form:
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateVectorToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating Vector Ops to SPIR-V ops.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...