29 #include "llvm/ADT/ArrayRef.h"
30 #include "llvm/ADT/STLExtras.h"
31 #include "llvm/ADT/SmallVector.h"
32 #include "llvm/ADT/SmallVectorExtras.h"
33 #include "llvm/Support/FormatVariadic.h"
44 return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
52 if (
auto vectorType = dyn_cast<VectorType>(type))
53 return vectorType.getNumElements() * vectorType.getElementTypeBitWidth();
63 matchAndRewrite(vector::ShapeCastOp shapeCastOp, OpAdaptor adaptor,
65 Type dstType = getTypeConverter()->convertType(shapeCastOp.getType());
71 if (dstType == adaptor.getSource().getType() ||
72 shapeCastOp.getResultVectorType().getNumElements() == 1) {
73 rewriter.
replaceOp(shapeCastOp, adaptor.getSource());
82 struct VectorBitcastConvert final
87 matchAndRewrite(vector::BitCastOp bitcastOp, OpAdaptor adaptor,
89 Type dstType = getTypeConverter()->convertType(bitcastOp.getType());
93 if (dstType == adaptor.getSource().getType()) {
94 rewriter.
replaceOp(bitcastOp, adaptor.getSource());
101 Type srcType = adaptor.getSource().getType();
105 llvm::formatv(
"different source ({0}) and target ({1}) bitwidth",
110 adaptor.getSource());
115 struct VectorBroadcastConvert final
120 matchAndRewrite(vector::BroadcastOp castOp, OpAdaptor adaptor,
123 getTypeConverter()->convertType(castOp.getResultVectorType());
127 if (isa<spirv::ScalarType>(resultType)) {
128 rewriter.
replaceOp(castOp, adaptor.getSource());
133 adaptor.getSource());
148 int64_t kPoisonIndex,
unsigned vectorSize) {
149 if (llvm::isPowerOf2_32(vectorSize)) {
150 Value inBoundsMask = rewriter.
create<spirv::ConstantOp>(
153 return rewriter.
create<spirv::BitwiseAndOp>(loc, dynamicIndex,
156 Value poisonIndex = rewriter.
create<spirv::ConstantOp>(
160 rewriter.
create<spirv::IEqualOp>(loc, dynamicIndex, poisonIndex);
161 return rewriter.
create<spirv::SelectOp>(
167 struct VectorExtractOpConvert final
172 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
174 Type dstType = getTypeConverter()->convertType(extractOp.getType());
178 if (isa<spirv::ScalarType>(adaptor.getVector().getType())) {
179 rewriter.
replaceOp(extractOp, adaptor.getVector());
183 if (std::optional<int64_t>
id =
185 if (
id == vector::ExtractOp::kPoisonIndex)
188 "Static use of poison index handled elsewhere (folded to poison)");
190 extractOp, dstType, adaptor.getVector(),
193 Value sanitizedIndex = sanitizeDynamicIndex(
194 rewriter, extractOp.getLoc(), adaptor.getDynamicPosition()[0],
195 vector::ExtractOp::kPoisonIndex,
196 extractOp.getSourceVectorType().getNumElements());
198 extractOp, dstType, adaptor.getVector(), sanitizedIndex);
204 struct VectorExtractStridedSliceOpConvert final
209 matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
211 Type dstType = getTypeConverter()->convertType(extractOp.getType());
221 Value srcVector = adaptor.getOperands().front();
224 if (isa<spirv::ScalarType>(dstType)) {
231 std::iota(indices.begin(), indices.end(), offset);
234 extractOp, dstType, srcVector, srcVector,
241 template <
class SPIRVFMAOp>
246 matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
248 Type dstType = getTypeConverter()->convertType(fmaOp.getType());
252 adaptor.getRhs(), adaptor.getAcc());
257 struct VectorFromElementsOpConvert final
262 matchAndRewrite(vector::FromElementsOp op, OpAdaptor adaptor,
264 Type resultType = getTypeConverter()->convertType(op.getType());
268 if (isa<spirv::ScalarType>(resultType)) {
276 assert(cast<VectorType>(resultType).getRank() == 1);
283 struct VectorInsertOpConvert final
288 matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
290 if (isa<VectorType>(insertOp.getValueToStoreType()))
292 if (!getTypeConverter()->convertType(insertOp.getDestVectorType()))
294 "unsupported dest vector type");
297 if (insertOp.getValueToStoreType().isIntOrFloat() &&
298 insertOp.getDestVectorType().getNumElements() == 1) {
299 rewriter.
replaceOp(insertOp, adaptor.getValueToStore());
303 if (std::optional<int64_t>
id =
305 if (
id == vector::InsertOp::kPoisonIndex)
308 "Static use of poison index handled elsewhere (folded to poison)");
310 insertOp, adaptor.getValueToStore(), adaptor.getDest(),
id.value());
312 Value sanitizedIndex = sanitizeDynamicIndex(
313 rewriter, insertOp.getLoc(), adaptor.getDynamicPosition()[0],
314 vector::InsertOp::kPoisonIndex,
315 insertOp.getDestVectorType().getNumElements());
317 insertOp, insertOp.getDest(), adaptor.getValueToStore(),
324 struct VectorExtractElementOpConvert final
329 matchAndRewrite(vector::ExtractElementOp extractOp, OpAdaptor adaptor,
331 Type resultType = getTypeConverter()->convertType(extractOp.getType());
335 if (isa<spirv::ScalarType>(adaptor.getVector().getType())) {
336 rewriter.
replaceOp(extractOp, adaptor.getVector());
343 extractOp, resultType, adaptor.getVector(),
347 extractOp, resultType, adaptor.getVector(), adaptor.getPosition());
352 struct VectorInsertElementOpConvert final
357 matchAndRewrite(vector::InsertElementOp insertOp, OpAdaptor adaptor,
359 Type vectorType = getTypeConverter()->convertType(insertOp.getType());
363 if (isa<spirv::ScalarType>(vectorType)) {
364 rewriter.
replaceOp(insertOp, adaptor.getSource());
371 insertOp, adaptor.getSource(), adaptor.getDest(),
372 cstPos.getSExtValue());
375 insertOp, vectorType, insertOp.getDest(), adaptor.getSource(),
376 adaptor.getPosition());
381 struct VectorInsertStridedSliceOpConvert final
386 matchAndRewrite(vector::InsertStridedSliceOp insertOp, OpAdaptor adaptor,
388 Value srcVector = adaptor.getOperands().front();
389 Value dstVector = adaptor.getOperands().back();
396 if (isa<spirv::ScalarType>(srcVector.
getType())) {
397 assert(!isa<spirv::ScalarType>(dstVector.
getType()));
399 insertOp, dstVector.
getType(), srcVector, dstVector,
404 uint64_t totalSize = cast<VectorType>(dstVector.getType()).getNumElements();
405 uint64_t insertSize =
406 cast<VectorType>(srcVector.
getType()).getNumElements();
409 std::iota(indices.begin(), indices.end(), 0);
410 std::iota(indices.begin() + offset, indices.begin() + offset + insertSize,
414 insertOp, dstVector.getType(), dstVector, srcVector,
422 vector::ReductionOp reduceOp, vector::ReductionOp::Adaptor adaptor,
424 int numElements =
static_cast<int>(srcVectorType.getDimSize(0));
426 values.reserve(numElements + (adaptor.getAcc() ? 1 : 0));
429 for (
int i = 0; i < numElements; ++i) {
430 values.push_back(rewriter.
create<spirv::CompositeExtractOp>(
431 loc, srcVectorType.getElementType(), adaptor.getVector(),
434 if (
Value acc = adaptor.getAcc())
435 values.push_back(acc);
440 struct ReductionRewriteInfo {
445 FailureOr<ReductionRewriteInfo>
static getReductionInfo(
446 vector::ReductionOp op, vector::ReductionOp::Adaptor adaptor,
452 auto srcVectorType = dyn_cast<VectorType>(adaptor.getVector().getType());
453 if (!srcVectorType || srcVectorType.getRank() != 1)
457 extractAllElements(op, adaptor, srcVectorType, rewriter);
459 return ReductionRewriteInfo{resultType, std::move(extractedElements)};
462 template <
typename SPIRVUMaxOp,
typename SPIRVUMinOp,
typename SPIRVSMaxOp,
463 typename SPIRVSMinOp>
468 matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
471 getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
472 if (failed(reductionInfo))
475 auto [resultType, extractedElements] = *reductionInfo;
477 Value result = extractedElements.front();
478 for (
Value next : llvm::drop_begin(extractedElements)) {
479 switch (reduceOp.getKind()) {
481 #define INT_AND_FLOAT_CASE(kind, iop, fop) \
482 case vector::CombiningKind::kind: \
483 if (llvm::isa<IntegerType>(resultType)) { \
484 result = rewriter.create<spirv::iop>(loc, resultType, result, next); \
486 assert(llvm::isa<FloatType>(resultType)); \
487 result = rewriter.create<spirv::fop>(loc, resultType, result, next); \
491 #define INT_OR_FLOAT_CASE(kind, fop) \
492 case vector::CombiningKind::kind: \
493 result = rewriter.create<fop>(loc, resultType, result, next); \
503 case vector::CombiningKind::AND:
504 case vector::CombiningKind::OR:
505 case vector::CombiningKind::XOR:
510 #undef INT_AND_FLOAT_CASE
511 #undef INT_OR_FLOAT_CASE
519 template <
typename SPIRVFMaxOp,
typename SPIRVFMinOp>
520 struct VectorReductionFloatMinMax final
525 matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
528 getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
529 if (failed(reductionInfo))
532 auto [resultType, extractedElements] = *reductionInfo;
534 Value result = extractedElements.front();
535 for (
Value next : llvm::drop_begin(extractedElements)) {
536 switch (reduceOp.getKind()) {
538 #define INT_OR_FLOAT_CASE(kind, fop) \
539 case vector::CombiningKind::kind: \
540 result = rewriter.create<fop>(loc, resultType, result, next); \
551 #undef INT_OR_FLOAT_CASE
564 matchAndRewrite(vector::SplatOp op, OpAdaptor adaptor,
566 Type dstType = getTypeConverter()->convertType(op.getType());
569 if (isa<spirv::ScalarType>(dstType)) {
570 rewriter.
replaceOp(op, adaptor.getInput());
572 auto dstVecType = cast<VectorType>(dstType);
582 struct VectorShuffleOpConvert final
587 matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
589 VectorType oldResultType = shuffleOp.getResultVectorType();
590 Type newResultType = getTypeConverter()->convertType(oldResultType);
593 "unsupported result vector type");
595 auto mask = llvm::to_vector_of<int32_t>(shuffleOp.getMask());
597 VectorType oldV1Type = shuffleOp.getV1VectorType();
598 VectorType oldV2Type = shuffleOp.getV2VectorType();
602 if (oldV1Type.getNumElements() > 1 && oldV2Type.getNumElements() > 1 &&
603 oldResultType.getNumElements() > 1) {
605 shuffleOp, newResultType, adaptor.getV1(), adaptor.getV2(),
613 auto getElementAtIdx = [&rewriter, loc = shuffleOp.getLoc()](
615 if (
auto vecTy = dyn_cast<VectorType>(scalarOrVec.getType()))
616 return rewriter.
create<spirv::CompositeExtractOp>(loc, scalarOrVec,
619 assert(idx == 0 &&
"Invalid scalar element index");
623 int32_t numV1Elems = oldV1Type.getNumElements();
625 for (
auto [shuffleIdx, newOperand] : llvm::zip_equal(mask, newOperands)) {
626 Value vec = adaptor.getV1();
627 int32_t elementIdx = shuffleIdx;
628 if (elementIdx >= numV1Elems) {
629 vec = adaptor.getV2();
630 elementIdx -= numV1Elems;
633 newOperand = getElementAtIdx(vec, elementIdx);
637 if (newOperands.size() == 1) {
638 rewriter.
replaceOp(shuffleOp, newOperands.front());
643 shuffleOp, newResultType, newOperands);
648 struct VectorInterleaveOpConvert final
653 matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
656 VectorType oldResultType = interleaveOp.getResultVectorType();
657 Type newResultType = getTypeConverter()->convertType(oldResultType);
660 "unsupported result vector type");
663 VectorType sourceType = interleaveOp.getSourceVectorType();
664 int n = sourceType.getNumElements();
670 Value newOperands[] = {adaptor.getLhs(), adaptor.getRhs()};
672 interleaveOp, newResultType, newOperands);
676 auto seq = llvm::seq<int64_t>(2 * n);
677 auto indices = llvm::map_to_vector(
678 seq, [n](
int i) {
return (i % 2 ? n : 0) + i / 2; });
682 interleaveOp, newResultType, adaptor.getLhs(), adaptor.getRhs(),
689 struct VectorDeinterleaveOpConvert final
694 matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
698 VectorType oldResultType = deinterleaveOp.getResultVectorType();
699 Type newResultType = getTypeConverter()->convertType(oldResultType);
702 "unsupported result vector type");
704 Location loc = deinterleaveOp->getLoc();
707 Value sourceVector = adaptor.getSource();
708 VectorType sourceType = deinterleaveOp.getSourceVectorType();
709 int n = sourceType.getNumElements();
715 auto elem0 = rewriter.
create<spirv::CompositeExtractOp>(
718 auto elem1 = rewriter.
create<spirv::CompositeExtractOp>(
721 rewriter.
replaceOp(deinterleaveOp, {elem0, elem1});
726 auto seqEven = llvm::seq<int64_t>(n / 2);
728 llvm::map_to_vector(seqEven, [](
int i) {
return i * 2; });
731 auto seqOdd = llvm::seq<int64_t>(n / 2);
733 llvm::map_to_vector(seqOdd, [](
int i) {
return i * 2 + 1; });
736 auto shuffleEven = rewriter.
create<spirv::VectorShuffleOp>(
737 loc, newResultType, sourceVector, sourceVector,
740 auto shuffleOdd = rewriter.
create<spirv::VectorShuffleOp>(
741 loc, newResultType, sourceVector, sourceVector,
744 rewriter.
replaceOp(deinterleaveOp, {shuffleEven, shuffleOdd});
749 struct VectorLoadOpConverter final
754 matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
756 auto memrefType = loadOp.getMemRefType();
758 dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
761 loadOp,
"expected spirv.storage_class memory space");
763 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
764 auto loc = loadOp.getLoc();
767 adaptor.getIndices(), loc, rewriter);
770 loadOp,
"failed to get memref element pointer");
772 spirv::StorageClass storageClass = attr.getValue();
773 auto vectorType = loadOp.getVectorType();
776 auto spirvVectorType = typeConverter.
convertType(vectorType);
782 Value castedAccessChain = (vectorType.getNumElements() == 1)
784 : rewriter.
create<spirv::BitcastOp>(
785 loc, vectorPtrType, accessChain);
794 struct VectorStoreOpConverter final
799 matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
801 auto memrefType = storeOp.getMemRefType();
803 dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
806 storeOp,
"expected spirv.storage_class memory space");
808 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
809 auto loc = storeOp.getLoc();
812 adaptor.getIndices(), loc, rewriter);
815 storeOp,
"failed to get memref element pointer");
817 spirv::StorageClass storageClass = attr.getValue();
818 auto vectorType = storeOp.getVectorType();
824 Value castedAccessChain = (vectorType.getNumElements() == 1)
826 : rewriter.
create<spirv::BitcastOp>(
827 loc, vectorPtrType, accessChain);
830 adaptor.getValueToStore());
836 struct VectorReductionToIntDotProd final
840 LogicalResult matchAndRewrite(vector::ReductionOp op,
842 if (op.getKind() != vector::CombiningKind::ADD)
845 auto resultType = dyn_cast<IntegerType>(op.getType());
850 if (!llvm::is_contained({32, 64}, resultBitwidth))
853 VectorType inVecTy = op.getSourceVectorType();
854 if (!llvm::is_contained({4, 3}, inVecTy.getNumElements()) ||
855 inVecTy.getShape().size() != 1 || inVecTy.isScalable())
858 auto mul = op.getVector().getDefiningOp<arith::MulIOp>();
861 op,
"reduction operand is not 'arith.muli'");
863 if (succeeded(handleCase<arith::ExtSIOp, arith::ExtSIOp, spirv::SDotOp,
864 spirv::SDotAccSatOp,
false>(op, mul, rewriter)))
867 if (succeeded(handleCase<arith::ExtUIOp, arith::ExtUIOp, spirv::UDotOp,
868 spirv::UDotAccSatOp,
false>(op, mul, rewriter)))
871 if (succeeded(handleCase<arith::ExtSIOp, arith::ExtUIOp, spirv::SUDotOp,
872 spirv::SUDotAccSatOp,
false>(op, mul, rewriter)))
875 if (succeeded(handleCase<arith::ExtUIOp, arith::ExtSIOp, spirv::SUDotOp,
876 spirv::SUDotAccSatOp,
true>(op, mul, rewriter)))
883 template <
typename LhsExtensionOp,
typename RhsExtensionOp,
typename DotOp,
884 typename DotAccOp,
bool SwapOperands>
885 static LogicalResult handleCase(vector::ReductionOp op, arith::MulIOp mul,
887 auto lhs = mul.getLhs().getDefiningOp<LhsExtensionOp>();
890 Value lhsIn = lhs.getIn();
891 auto lhsInType = cast<VectorType>(lhsIn.
getType());
892 if (!lhsInType.getElementType().isInteger(8))
895 auto rhs = mul.getRhs().getDefiningOp<RhsExtensionOp>();
898 Value rhsIn = rhs.getIn();
899 auto rhsInType = cast<VectorType>(rhsIn.
getType());
900 if (!rhsInType.getElementType().isInteger(8))
903 if (op.getSourceVectorType().getNumElements() == 3) {
904 IntegerType i8Type = rewriter.
getI8Type();
908 lhsIn = rewriter.
create<spirv::CompositeConstructOp>(
910 rhsIn = rewriter.
create<spirv::CompositeConstructOp>(
917 std::swap(lhsIn, rhsIn);
919 if (
Value acc = op.getAcc()) {
931 struct VectorReductionToFPDotProd final
936 matchAndRewrite(vector::ReductionOp op, OpAdaptor adaptor,
938 if (op.getKind() != vector::CombiningKind::ADD)
941 auto resultType = getTypeConverter()->convertType<FloatType>(op.getType());
945 Value vec = adaptor.getVector();
946 Value acc = adaptor.getAcc();
948 auto vectorType = dyn_cast<VectorType>(vec.
getType());
950 assert(isa<FloatType>(vec.
getType()) &&
951 "Expected the vector to be scalarized");
972 rewriter.
getFloatAttr(vectorType.getElementType(), 1.0);
974 rhs = rewriter.
create<spirv::ConstantOp>(loc, vectorType, oneAttr);
979 Value res = rewriter.
create<spirv::DotOp>(loc, resultType, lhs, rhs);
981 res = rewriter.
create<spirv::FAddOp>(loc, acc, res);
992 matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
994 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
1000 int64_t numElements = stepOp.getType().getNumElements();
1006 if (numElements == 1) {
1013 source.reserve(numElements);
1014 for (int64_t i = 0; i < numElements; ++i) {
1016 Value constOp = rewriter.
create<spirv::ConstantOp>(loc, intType, intAttr);
1017 source.push_back(constOp);
1026 #define CL_INT_MAX_MIN_OPS \
1027 spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
1029 #define GL_INT_MAX_MIN_OPS \
1030 spirv::GLUMaxOp, spirv::GLUMinOp, spirv::GLSMaxOp, spirv::GLSMinOp
1032 #define CL_FLOAT_MAX_MIN_OPS spirv::CLFMaxOp, spirv::CLFMinOp
1033 #define GL_FLOAT_MAX_MIN_OPS spirv::GLFMaxOp, spirv::GLFMinOp
1038 VectorBitcastConvert, VectorBroadcastConvert,
1039 VectorExtractElementOpConvert, VectorExtractOpConvert,
1040 VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
1041 VectorFmaOpConvert<spirv::CLFmaOp>, VectorFromElementsOpConvert,
1042 VectorInsertElementOpConvert, VectorInsertOpConvert,
1043 VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
1044 VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
1045 VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
1046 VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
1047 VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
1048 VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
1049 VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter,
1050 VectorStepOpConvert>(typeConverter,
patterns.getContext(),
1055 patterns.add<VectorReductionToFPDotProd>(typeConverter,
patterns.getContext(),
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static uint64_t getFirstIntValue(ArrayAttr attr)
Returns the integer value from the first valid input element, assuming Value inputs are defined by a ...
static int getNumBits(Type type)
Returns the number of bits for the given scalar/vector type.
#define INT_AND_FLOAT_CASE(kind, iop, fop)
#define INT_OR_FLOAT_CASE(kind, fop)
Attributes are known-constant values of operations.
IntegerAttr getIntegerAttr(Type type, int64_t value)
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class implements the operand iterators for the Operation class.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Type conversion from builtin types to SPIR-V types for shader interface.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static PointerType get(Type pointeeType, StorageClass storageClass)
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void populateVectorReductionToSPIRVDotProductPatterns(RewritePatternSet &patterns)
Appends patterns to convert vector reduction of the form:
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateVectorToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating Vector Ops to SPIR-V ops.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...