28 #include "llvm/ADT/ArrayRef.h"
29 #include "llvm/ADT/STLExtras.h"
30 #include "llvm/ADT/SmallVector.h"
31 #include "llvm/ADT/SmallVectorExtras.h"
32 #include "llvm/Support/FormatVariadic.h"
43 return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
51 if (
auto vectorType = dyn_cast<VectorType>(type))
52 return vectorType.getNumElements() * vectorType.getElementTypeBitWidth();
62 matchAndRewrite(vector::ShapeCastOp shapeCastOp, OpAdaptor adaptor,
64 Type dstType = getTypeConverter()->convertType(shapeCastOp.getType());
70 if (dstType == adaptor.getSource().getType() ||
71 shapeCastOp.getResultVectorType().getNumElements() == 1) {
72 rewriter.
replaceOp(shapeCastOp, adaptor.getSource());
81 struct VectorBitcastConvert final
86 matchAndRewrite(vector::BitCastOp bitcastOp, OpAdaptor adaptor,
88 Type dstType = getTypeConverter()->convertType(bitcastOp.getType());
92 if (dstType == adaptor.getSource().getType()) {
93 rewriter.
replaceOp(bitcastOp, adaptor.getSource());
100 Type srcType = adaptor.getSource().getType();
104 llvm::formatv(
"different source ({0}) and target ({1}) bitwidth",
109 adaptor.getSource());
114 struct VectorBroadcastConvert final
119 matchAndRewrite(vector::BroadcastOp castOp, OpAdaptor adaptor,
122 getTypeConverter()->convertType(castOp.getResultVectorType());
126 if (isa<spirv::ScalarType>(resultType)) {
127 rewriter.
replaceOp(castOp, adaptor.getSource());
132 adaptor.getSource());
147 int64_t kPoisonIndex,
unsigned vectorSize) {
148 if (llvm::isPowerOf2_32(vectorSize)) {
149 Value inBoundsMask = spirv::ConstantOp::create(
150 rewriter, loc, dynamicIndex.
getType(),
152 return spirv::BitwiseAndOp::create(rewriter, loc, dynamicIndex,
155 Value poisonIndex = spirv::ConstantOp::create(
156 rewriter, loc, dynamicIndex.
getType(),
159 spirv::IEqualOp::create(rewriter, loc, dynamicIndex, poisonIndex);
160 return spirv::SelectOp::create(
161 rewriter, loc, cmpResult,
166 struct VectorExtractOpConvert final
171 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
173 Type dstType = getTypeConverter()->convertType(extractOp.getType());
177 if (isa<spirv::ScalarType>(adaptor.getSource().getType())) {
178 rewriter.
replaceOp(extractOp, adaptor.getSource());
182 if (std::optional<int64_t>
id =
184 if (
id == vector::ExtractOp::kPoisonIndex)
187 "Static use of poison index handled elsewhere (folded to poison)");
189 extractOp, dstType, adaptor.getSource(),
192 Value sanitizedIndex = sanitizeDynamicIndex(
193 rewriter, extractOp.getLoc(), adaptor.getDynamicPosition()[0],
194 vector::ExtractOp::kPoisonIndex,
195 extractOp.getSourceVectorType().getNumElements());
197 extractOp, dstType, adaptor.getSource(), sanitizedIndex);
203 struct VectorExtractStridedSliceOpConvert final
208 matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
210 Type dstType = getTypeConverter()->convertType(extractOp.getType());
220 Value srcVector = adaptor.getOperands().front();
223 if (isa<spirv::ScalarType>(dstType)) {
230 std::iota(indices.begin(), indices.end(), offset);
233 extractOp, dstType, srcVector, srcVector,
240 template <
class SPIRVFMAOp>
245 matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
247 Type dstType = getTypeConverter()->convertType(fmaOp.getType());
251 adaptor.getRhs(), adaptor.getAcc());
256 struct VectorFromElementsOpConvert final
261 matchAndRewrite(vector::FromElementsOp op, OpAdaptor adaptor,
263 Type resultType = getTypeConverter()->convertType(op.getType());
267 if (isa<spirv::ScalarType>(resultType)) {
275 assert(cast<VectorType>(resultType).getRank() == 1);
282 struct VectorInsertOpConvert final
287 matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
289 if (isa<VectorType>(insertOp.getValueToStoreType()))
291 if (!getTypeConverter()->convertType(insertOp.getDestVectorType()))
293 "unsupported dest vector type");
296 if (insertOp.getValueToStoreType().isIntOrFloat() &&
297 insertOp.getDestVectorType().getNumElements() == 1) {
298 rewriter.
replaceOp(insertOp, adaptor.getValueToStore());
302 if (std::optional<int64_t>
id =
304 if (
id == vector::InsertOp::kPoisonIndex)
307 "Static use of poison index handled elsewhere (folded to poison)");
309 insertOp, adaptor.getValueToStore(), adaptor.getDest(),
id.value());
311 Value sanitizedIndex = sanitizeDynamicIndex(
312 rewriter, insertOp.getLoc(), adaptor.getDynamicPosition()[0],
313 vector::InsertOp::kPoisonIndex,
314 insertOp.getDestVectorType().getNumElements());
316 insertOp, insertOp.getDest(), adaptor.getValueToStore(),
323 struct VectorInsertStridedSliceOpConvert final
328 matchAndRewrite(vector::InsertStridedSliceOp insertOp, OpAdaptor adaptor,
330 Value srcVector = adaptor.getOperands().front();
331 Value dstVector = adaptor.getOperands().back();
338 if (isa<spirv::ScalarType>(srcVector.
getType())) {
339 assert(!isa<spirv::ScalarType>(dstVector.
getType()));
341 insertOp, dstVector.
getType(), srcVector, dstVector,
346 uint64_t totalSize = cast<VectorType>(dstVector.getType()).getNumElements();
347 uint64_t insertSize =
348 cast<VectorType>(srcVector.
getType()).getNumElements();
351 std::iota(indices.begin(), indices.end(), 0);
352 std::iota(indices.begin() + offset, indices.begin() + offset + insertSize,
356 insertOp, dstVector.getType(), dstVector, srcVector,
364 vector::ReductionOp reduceOp, vector::ReductionOp::Adaptor adaptor,
366 int numElements =
static_cast<int>(srcVectorType.getDimSize(0));
368 values.reserve(numElements + (adaptor.getAcc() ? 1 : 0));
371 for (
int i = 0; i < numElements; ++i) {
372 values.push_back(spirv::CompositeExtractOp::create(
373 rewriter, loc, srcVectorType.getElementType(), adaptor.getVector(),
376 if (
Value acc = adaptor.getAcc())
377 values.push_back(acc);
382 struct ReductionRewriteInfo {
387 FailureOr<ReductionRewriteInfo>
static getReductionInfo(
388 vector::ReductionOp op, vector::ReductionOp::Adaptor adaptor,
394 auto srcVectorType = dyn_cast<VectorType>(adaptor.getVector().getType());
395 if (!srcVectorType || srcVectorType.getRank() != 1)
399 extractAllElements(op, adaptor, srcVectorType, rewriter);
401 return ReductionRewriteInfo{resultType, std::move(extractedElements)};
404 template <
typename SPIRVUMaxOp,
typename SPIRVUMinOp,
typename SPIRVSMaxOp,
405 typename SPIRVSMinOp>
410 matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
413 getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
414 if (
failed(reductionInfo))
417 auto [resultType, extractedElements] = *reductionInfo;
419 Value result = extractedElements.front();
420 for (
Value next : llvm::drop_begin(extractedElements)) {
421 switch (reduceOp.getKind()) {
423 #define INT_AND_FLOAT_CASE(kind, iop, fop) \
424 case vector::CombiningKind::kind: \
425 if (llvm::isa<IntegerType>(resultType)) { \
426 result = spirv::iop::create(rewriter, loc, resultType, result, next); \
428 assert(llvm::isa<FloatType>(resultType)); \
429 result = spirv::fop::create(rewriter, loc, resultType, result, next); \
433 #define INT_OR_FLOAT_CASE(kind, fop) \
434 case vector::CombiningKind::kind: \
435 result = fop::create(rewriter, loc, resultType, result, next); \
445 case vector::CombiningKind::AND:
446 case vector::CombiningKind::OR:
447 case vector::CombiningKind::XOR:
452 #undef INT_AND_FLOAT_CASE
453 #undef INT_OR_FLOAT_CASE
461 template <
typename SPIRVFMaxOp,
typename SPIRVFMinOp>
462 struct VectorReductionFloatMinMax final
467 matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
470 getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
471 if (
failed(reductionInfo))
474 auto [resultType, extractedElements] = *reductionInfo;
476 Value result = extractedElements.front();
477 for (
Value next : llvm::drop_begin(extractedElements)) {
478 switch (reduceOp.getKind()) {
480 #define INT_OR_FLOAT_CASE(kind, fop) \
481 case vector::CombiningKind::kind: \
482 result = fop::create(rewriter, loc, resultType, result, next); \
493 #undef INT_OR_FLOAT_CASE
501 class VectorScalarBroadcastPattern final
507 matchAndRewrite(vector::BroadcastOp op, OpAdaptor adaptor,
509 if (isa<VectorType>(op.getSourceType())) {
511 op,
"only conversion of 'broadcast from scalar' is supported");
513 Type dstType = getTypeConverter()->convertType(op.getType());
516 if (isa<spirv::ScalarType>(dstType)) {
517 rewriter.
replaceOp(op, adaptor.getSource());
519 auto dstVecType = cast<VectorType>(dstType);
521 adaptor.getSource());
529 struct VectorShuffleOpConvert final
534 matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
536 VectorType oldResultType = shuffleOp.getResultVectorType();
537 Type newResultType = getTypeConverter()->convertType(oldResultType);
540 "unsupported result vector type");
542 auto mask = llvm::to_vector_of<int32_t>(shuffleOp.getMask());
544 VectorType oldV1Type = shuffleOp.getV1VectorType();
545 VectorType oldV2Type = shuffleOp.getV2VectorType();
549 if (oldV1Type.getNumElements() > 1 && oldV2Type.getNumElements() > 1 &&
550 oldResultType.getNumElements() > 1) {
552 shuffleOp, newResultType, adaptor.getV1(), adaptor.getV2(),
560 auto getElementAtIdx = [&rewriter, loc = shuffleOp.getLoc()](
562 if (
auto vecTy = dyn_cast<VectorType>(scalarOrVec.getType()))
563 return spirv::CompositeExtractOp::create(rewriter, loc, scalarOrVec,
566 assert(idx == 0 &&
"Invalid scalar element index");
570 int32_t numV1Elems = oldV1Type.getNumElements();
572 for (
auto [shuffleIdx, newOperand] : llvm::zip_equal(mask, newOperands)) {
573 Value vec = adaptor.getV1();
574 int32_t elementIdx = shuffleIdx;
575 if (elementIdx >= numV1Elems) {
576 vec = adaptor.getV2();
577 elementIdx -= numV1Elems;
580 newOperand = getElementAtIdx(vec, elementIdx);
584 if (newOperands.size() == 1) {
585 rewriter.
replaceOp(shuffleOp, newOperands.front());
590 shuffleOp, newResultType, newOperands);
595 struct VectorInterleaveOpConvert final
600 matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
603 VectorType oldResultType = interleaveOp.getResultVectorType();
604 Type newResultType = getTypeConverter()->convertType(oldResultType);
607 "unsupported result vector type");
610 VectorType sourceType = interleaveOp.getSourceVectorType();
611 int n = sourceType.getNumElements();
617 Value newOperands[] = {adaptor.getLhs(), adaptor.getRhs()};
619 interleaveOp, newResultType, newOperands);
623 auto seq = llvm::seq<int64_t>(2 * n);
624 auto indices = llvm::map_to_vector(
625 seq, [n](
int i) {
return (i % 2 ? n : 0) + i / 2; });
629 interleaveOp, newResultType, adaptor.getLhs(), adaptor.getRhs(),
636 struct VectorDeinterleaveOpConvert final
641 matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
645 VectorType oldResultType = deinterleaveOp.getResultVectorType();
646 Type newResultType = getTypeConverter()->convertType(oldResultType);
649 "unsupported result vector type");
651 Location loc = deinterleaveOp->getLoc();
654 Value sourceVector = adaptor.getSource();
655 VectorType sourceType = deinterleaveOp.getSourceVectorType();
656 int n = sourceType.getNumElements();
662 auto elem0 = spirv::CompositeExtractOp::create(
663 rewriter, loc, newResultType, sourceVector,
666 auto elem1 = spirv::CompositeExtractOp::create(
667 rewriter, loc, newResultType, sourceVector,
670 rewriter.
replaceOp(deinterleaveOp, {elem0, elem1});
675 auto seqEven = llvm::seq<int64_t>(n / 2);
677 llvm::map_to_vector(seqEven, [](
int i) {
return i * 2; });
680 auto seqOdd = llvm::seq<int64_t>(n / 2);
682 llvm::map_to_vector(seqOdd, [](
int i) {
return i * 2 + 1; });
685 auto shuffleEven = spirv::VectorShuffleOp::create(
686 rewriter, loc, newResultType, sourceVector, sourceVector,
689 auto shuffleOdd = spirv::VectorShuffleOp::create(
690 rewriter, loc, newResultType, sourceVector, sourceVector,
693 rewriter.
replaceOp(deinterleaveOp, {shuffleEven, shuffleOdd});
698 struct VectorLoadOpConverter final
703 matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
705 auto memrefType = loadOp.getMemRefType();
707 dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
710 loadOp,
"expected spirv.storage_class memory space");
712 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
713 auto loc = loadOp.getLoc();
716 adaptor.getIndices(), loc, rewriter);
719 loadOp,
"failed to get memref element pointer");
721 spirv::StorageClass storageClass = attr.getValue();
722 auto vectorType = loadOp.getVectorType();
725 auto spirvVectorType = typeConverter.
convertType(vectorType);
726 if (!spirvVectorType)
731 std::optional<uint64_t> alignment = loadOp.getAlignment();
734 "invalid alignment requirement");
738 spirv::MemoryAccessAttr memoryAccessAttr;
739 IntegerAttr alignmentAttr;
740 if (alignment.has_value()) {
741 memoryAccess |= spirv::MemoryAccess::Aligned;
750 Value castedAccessChain =
751 (vectorType.getNumElements() == 1)
753 : spirv::BitcastOp::create(rewriter, loc, vectorPtrType,
758 memoryAccessAttr, alignmentAttr);
764 struct VectorStoreOpConverter final
769 matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
771 auto memrefType = storeOp.getMemRefType();
773 dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
776 storeOp,
"expected spirv.storage_class memory space");
778 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
779 auto loc = storeOp.getLoc();
782 adaptor.getIndices(), loc, rewriter);
785 storeOp,
"failed to get memref element pointer");
787 std::optional<uint64_t> alignment = storeOp.getAlignment();
790 "invalid alignment requirement");
793 spirv::StorageClass storageClass = attr.getValue();
794 auto vectorType = storeOp.getVectorType();
800 Value castedAccessChain =
801 (vectorType.getNumElements() == 1)
803 : spirv::BitcastOp::create(rewriter, loc, vectorPtrType,
807 spirv::MemoryAccessAttr memoryAccessAttr;
808 IntegerAttr alignmentAttr;
809 if (alignment.has_value()) {
810 memoryAccess |= spirv::MemoryAccess::Aligned;
817 storeOp, castedAccessChain, adaptor.getValueToStore(), memoryAccessAttr,
824 struct VectorReductionToIntDotProd final
828 LogicalResult matchAndRewrite(vector::ReductionOp op,
830 if (op.getKind() != vector::CombiningKind::ADD)
833 auto resultType = dyn_cast<IntegerType>(op.getType());
838 if (!llvm::is_contained({32, 64}, resultBitwidth))
841 VectorType inVecTy = op.getSourceVectorType();
842 if (!llvm::is_contained({4, 3}, inVecTy.getNumElements()) ||
843 inVecTy.getShape().size() != 1 || inVecTy.isScalable())
846 auto mul = op.getVector().getDefiningOp<arith::MulIOp>();
849 op,
"reduction operand is not 'arith.muli'");
851 if (succeeded(handleCase<arith::ExtSIOp, arith::ExtSIOp, spirv::SDotOp,
852 spirv::SDotAccSatOp,
false>(op, mul, rewriter)))
855 if (succeeded(handleCase<arith::ExtUIOp, arith::ExtUIOp, spirv::UDotOp,
856 spirv::UDotAccSatOp,
false>(op, mul, rewriter)))
859 if (succeeded(handleCase<arith::ExtSIOp, arith::ExtUIOp, spirv::SUDotOp,
860 spirv::SUDotAccSatOp,
false>(op, mul, rewriter)))
863 if (succeeded(handleCase<arith::ExtUIOp, arith::ExtSIOp, spirv::SUDotOp,
864 spirv::SUDotAccSatOp,
true>(op, mul, rewriter)))
871 template <
typename LhsExtensionOp,
typename RhsExtensionOp,
typename DotOp,
872 typename DotAccOp,
bool SwapOperands>
873 static LogicalResult handleCase(vector::ReductionOp op, arith::MulIOp mul,
875 auto lhs = mul.getLhs().getDefiningOp<LhsExtensionOp>();
878 Value lhsIn = lhs.getIn();
879 auto lhsInType = cast<VectorType>(lhsIn.
getType());
880 if (!lhsInType.getElementType().isInteger(8))
883 auto rhs = mul.getRhs().getDefiningOp<RhsExtensionOp>();
886 Value rhsIn = rhs.getIn();
887 auto rhsInType = cast<VectorType>(rhsIn.
getType());
888 if (!rhsInType.getElementType().isInteger(8))
891 if (op.getSourceVectorType().getNumElements() == 3) {
892 IntegerType i8Type = rewriter.
getI8Type();
896 lhsIn = spirv::CompositeConstructOp::create(rewriter, loc, v4i8Type,
898 rhsIn = spirv::CompositeConstructOp::create(rewriter, loc, v4i8Type,
905 std::swap(lhsIn, rhsIn);
907 if (
Value acc = op.getAcc()) {
919 struct VectorReductionToFPDotProd final
924 matchAndRewrite(vector::ReductionOp op, OpAdaptor adaptor,
926 if (op.getKind() != vector::CombiningKind::ADD)
929 auto resultType = getTypeConverter()->convertType<FloatType>(op.getType());
933 Value vec = adaptor.getVector();
934 Value acc = adaptor.getAcc();
936 auto vectorType = dyn_cast<VectorType>(vec.
getType());
938 assert(isa<FloatType>(vec.
getType()) &&
939 "Expected the vector to be scalarized");
960 rewriter.
getFloatAttr(vectorType.getElementType(), 1.0);
962 rhs = spirv::ConstantOp::create(rewriter, loc, vectorType, oneAttr);
967 Value res = spirv::DotOp::create(rewriter, loc, resultType, lhs, rhs);
969 res = spirv::FAddOp::create(rewriter, loc, acc, res);
980 matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
982 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
988 int64_t numElements = stepOp.getType().getNumElements();
994 if (numElements == 1) {
1001 source.reserve(numElements);
1002 for (int64_t i = 0; i < numElements; ++i) {
1005 spirv::ConstantOp::create(rewriter, loc, intType, intAttr);
1006 source.push_back(constOp);
1014 struct VectorToElementOpConvert final
1019 matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
1023 Location loc = toElementsOp.getLoc();
1028 if (isa<spirv::ScalarType>(adaptor.getSource().getType())) {
1029 results[0] = adaptor.getSource();
1030 rewriter.
replaceOp(toElementsOp, results);
1034 Type srcElementType = toElementsOp.getElements().getType().front();
1035 Type elementType = getTypeConverter()->convertType(srcElementType);
1039 llvm::formatv(
"failed to convert element type '{0}' to SPIR-V",
1042 for (
auto [idx, element] :
llvm::enumerate(toElementsOp.getElements())) {
1045 if (element.use_empty())
1048 Value result = spirv::CompositeExtractOp::create(
1049 rewriter, loc, elementType, adaptor.getSource(),
1051 results[idx] = result;
1054 rewriter.
replaceOp(toElementsOp, results);
1060 #define CL_INT_MAX_MIN_OPS \
1061 spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
1063 #define GL_INT_MAX_MIN_OPS \
1064 spirv::GLUMaxOp, spirv::GLUMinOp, spirv::GLSMaxOp, spirv::GLSMinOp
1066 #define CL_FLOAT_MAX_MIN_OPS spirv::CLFMaxOp, spirv::CLFMinOp
1067 #define GL_FLOAT_MAX_MIN_OPS spirv::GLFMaxOp, spirv::GLFMinOp
1072 VectorBitcastConvert, VectorBroadcastConvert, VectorExtractOpConvert,
1073 VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
1074 VectorFmaOpConvert<spirv::CLFmaOp>, VectorFromElementsOpConvert,
1075 VectorToElementOpConvert, VectorInsertOpConvert,
1076 VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
1077 VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
1078 VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
1079 VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
1080 VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
1081 VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
1082 VectorScalarBroadcastPattern, VectorLoadOpConverter,
1083 VectorStoreOpConverter, VectorStepOpConvert>(
1088 patterns.add<VectorReductionToFPDotProd>(typeConverter,
patterns.getContext(),
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static uint64_t getFirstIntValue(ArrayAttr attr)
Returns the integer value from the first valid input element, assuming Value inputs are defined by a ...
static int getNumBits(Type type)
Returns the number of bits for the given scalar/vector type.
#define INT_AND_FLOAT_CASE(kind, iop, fop)
#define INT_OR_FLOAT_CASE(kind, fop)
Attributes are known-constant values of operations.
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Type conversion from builtin types to SPIR-V types for shader interface.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static PointerType get(Type pointeeType, StorageClass storageClass)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void populateVectorReductionToSPIRVDotProductPatterns(RewritePatternSet &patterns)
Appends patterns to convert vector reduction of the form:
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateVectorToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating Vector Ops to SPIR-V ops.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...