29 #include "llvm/ADT/ArrayRef.h"
30 #include "llvm/ADT/STLExtras.h"
31 #include "llvm/ADT/SmallVector.h"
32 #include "llvm/ADT/SmallVectorExtras.h"
33 #include "llvm/Support/FormatVariadic.h"
44 return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
52 if (
auto vectorType = dyn_cast<VectorType>(type))
53 return vectorType.getNumElements() * vectorType.getElementTypeBitWidth();
63 matchAndRewrite(vector::ShapeCastOp shapeCastOp, OpAdaptor adaptor,
65 Type dstType = getTypeConverter()->convertType(shapeCastOp.getType());
71 if (dstType == adaptor.getSource().getType() ||
72 shapeCastOp.getResultVectorType().getNumElements() == 1) {
73 rewriter.
replaceOp(shapeCastOp, adaptor.getSource());
82 struct VectorBitcastConvert final
87 matchAndRewrite(vector::BitCastOp bitcastOp, OpAdaptor adaptor,
89 Type dstType = getTypeConverter()->convertType(bitcastOp.getType());
93 if (dstType == adaptor.getSource().getType()) {
94 rewriter.
replaceOp(bitcastOp, adaptor.getSource());
101 Type srcType = adaptor.getSource().getType();
105 llvm::formatv(
"different source ({0}) and target ({1}) bitwidth",
110 adaptor.getSource());
115 struct VectorBroadcastConvert final
120 matchAndRewrite(vector::BroadcastOp castOp, OpAdaptor adaptor,
123 getTypeConverter()->convertType(castOp.getResultVectorType());
127 if (isa<spirv::ScalarType>(resultType)) {
128 rewriter.
replaceOp(castOp, adaptor.getSource());
133 adaptor.getSource());
148 int64_t kPoisonIndex,
unsigned vectorSize) {
149 if (llvm::isPowerOf2_32(vectorSize)) {
150 Value inBoundsMask = rewriter.
create<spirv::ConstantOp>(
153 return rewriter.
create<spirv::BitwiseAndOp>(loc, dynamicIndex,
156 Value poisonIndex = rewriter.
create<spirv::ConstantOp>(
160 rewriter.
create<spirv::IEqualOp>(loc, dynamicIndex, poisonIndex);
161 return rewriter.
create<spirv::SelectOp>(
167 struct VectorExtractOpConvert final
172 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
174 Type dstType = getTypeConverter()->convertType(extractOp.getType());
178 if (isa<spirv::ScalarType>(adaptor.getVector().getType())) {
179 rewriter.
replaceOp(extractOp, adaptor.getVector());
183 if (std::optional<int64_t>
id =
185 if (
id == vector::ExtractOp::kPoisonIndex)
188 "Static use of poison index handled elsewhere (folded to poison)");
190 extractOp, dstType, adaptor.getVector(),
193 Value sanitizedIndex = sanitizeDynamicIndex(
194 rewriter, extractOp.getLoc(), adaptor.getDynamicPosition()[0],
195 vector::ExtractOp::kPoisonIndex,
196 extractOp.getSourceVectorType().getNumElements());
198 extractOp, dstType, adaptor.getVector(), sanitizedIndex);
204 struct VectorExtractStridedSliceOpConvert final
209 matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
211 Type dstType = getTypeConverter()->convertType(extractOp.getType());
221 Value srcVector = adaptor.getOperands().front();
224 if (isa<spirv::ScalarType>(dstType)) {
231 std::iota(indices.begin(), indices.end(), offset);
234 extractOp, dstType, srcVector, srcVector,
241 template <
class SPIRVFMAOp>
246 matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
248 Type dstType = getTypeConverter()->convertType(fmaOp.getType());
252 adaptor.getRhs(), adaptor.getAcc());
257 struct VectorFromElementsOpConvert final
262 matchAndRewrite(vector::FromElementsOp op, OpAdaptor adaptor,
264 Type resultType = getTypeConverter()->convertType(op.getType());
268 if (isa<spirv::ScalarType>(resultType)) {
276 assert(cast<VectorType>(resultType).getRank() == 1);
283 struct VectorInsertOpConvert final
288 matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
290 if (isa<VectorType>(insertOp.getSourceType()))
292 if (!getTypeConverter()->convertType(insertOp.getDestVectorType()))
294 "unsupported dest vector type");
297 if (insertOp.getSourceType().isIntOrFloat() &&
298 insertOp.getDestVectorType().getNumElements() == 1) {
299 rewriter.
replaceOp(insertOp, adaptor.getSource());
303 if (std::optional<int64_t>
id =
305 if (
id == vector::InsertOp::kPoisonIndex)
308 "Static use of poison index handled elsewhere (folded to poison)");
310 insertOp, adaptor.getSource(), adaptor.getDest(),
id.value());
312 Value sanitizedIndex = sanitizeDynamicIndex(
313 rewriter, insertOp.getLoc(), adaptor.getDynamicPosition()[0],
314 vector::InsertOp::kPoisonIndex,
315 insertOp.getDestVectorType().getNumElements());
317 insertOp, insertOp.getDest(), adaptor.getSource(), sanitizedIndex);
323 struct VectorExtractElementOpConvert final
328 matchAndRewrite(vector::ExtractElementOp extractOp, OpAdaptor adaptor,
330 Type resultType = getTypeConverter()->convertType(extractOp.getType());
334 if (isa<spirv::ScalarType>(adaptor.getVector().getType())) {
335 rewriter.
replaceOp(extractOp, adaptor.getVector());
342 extractOp, resultType, adaptor.getVector(),
346 extractOp, resultType, adaptor.getVector(), adaptor.getPosition());
351 struct VectorInsertElementOpConvert final
356 matchAndRewrite(vector::InsertElementOp insertOp, OpAdaptor adaptor,
358 Type vectorType = getTypeConverter()->convertType(insertOp.getType());
362 if (isa<spirv::ScalarType>(vectorType)) {
363 rewriter.
replaceOp(insertOp, adaptor.getSource());
370 insertOp, adaptor.getSource(), adaptor.getDest(),
371 cstPos.getSExtValue());
374 insertOp, vectorType, insertOp.getDest(), adaptor.getSource(),
375 adaptor.getPosition());
380 struct VectorInsertStridedSliceOpConvert final
385 matchAndRewrite(vector::InsertStridedSliceOp insertOp, OpAdaptor adaptor,
387 Value srcVector = adaptor.getOperands().front();
388 Value dstVector = adaptor.getOperands().back();
395 if (isa<spirv::ScalarType>(srcVector.
getType())) {
396 assert(!isa<spirv::ScalarType>(dstVector.
getType()));
398 insertOp, dstVector.
getType(), srcVector, dstVector,
403 uint64_t totalSize = cast<VectorType>(dstVector.getType()).getNumElements();
404 uint64_t insertSize =
405 cast<VectorType>(srcVector.
getType()).getNumElements();
408 std::iota(indices.begin(), indices.end(), 0);
409 std::iota(indices.begin() + offset, indices.begin() + offset + insertSize,
413 insertOp, dstVector.getType(), dstVector, srcVector,
421 vector::ReductionOp reduceOp, vector::ReductionOp::Adaptor adaptor,
423 int numElements =
static_cast<int>(srcVectorType.getDimSize(0));
425 values.reserve(numElements + (adaptor.getAcc() ? 1 : 0));
428 for (
int i = 0; i < numElements; ++i) {
429 values.push_back(rewriter.
create<spirv::CompositeExtractOp>(
430 loc, srcVectorType.getElementType(), adaptor.getVector(),
433 if (
Value acc = adaptor.getAcc())
434 values.push_back(acc);
439 struct ReductionRewriteInfo {
444 FailureOr<ReductionRewriteInfo>
static getReductionInfo(
445 vector::ReductionOp op, vector::ReductionOp::Adaptor adaptor,
451 auto srcVectorType = dyn_cast<VectorType>(adaptor.getVector().getType());
452 if (!srcVectorType || srcVectorType.getRank() != 1)
456 extractAllElements(op, adaptor, srcVectorType, rewriter);
458 return ReductionRewriteInfo{resultType, std::move(extractedElements)};
461 template <
typename SPIRVUMaxOp,
typename SPIRVUMinOp,
typename SPIRVSMaxOp,
462 typename SPIRVSMinOp>
467 matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
470 getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
471 if (failed(reductionInfo))
474 auto [resultType, extractedElements] = *reductionInfo;
476 Value result = extractedElements.front();
477 for (
Value next : llvm::drop_begin(extractedElements)) {
478 switch (reduceOp.getKind()) {
480 #define INT_AND_FLOAT_CASE(kind, iop, fop) \
481 case vector::CombiningKind::kind: \
482 if (llvm::isa<IntegerType>(resultType)) { \
483 result = rewriter.create<spirv::iop>(loc, resultType, result, next); \
485 assert(llvm::isa<FloatType>(resultType)); \
486 result = rewriter.create<spirv::fop>(loc, resultType, result, next); \
490 #define INT_OR_FLOAT_CASE(kind, fop) \
491 case vector::CombiningKind::kind: \
492 result = rewriter.create<fop>(loc, resultType, result, next); \
502 case vector::CombiningKind::AND:
503 case vector::CombiningKind::OR:
504 case vector::CombiningKind::XOR:
509 #undef INT_AND_FLOAT_CASE
510 #undef INT_OR_FLOAT_CASE
518 template <
typename SPIRVFMaxOp,
typename SPIRVFMinOp>
519 struct VectorReductionFloatMinMax final
524 matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
527 getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
528 if (failed(reductionInfo))
531 auto [resultType, extractedElements] = *reductionInfo;
533 Value result = extractedElements.front();
534 for (
Value next : llvm::drop_begin(extractedElements)) {
535 switch (reduceOp.getKind()) {
537 #define INT_OR_FLOAT_CASE(kind, fop) \
538 case vector::CombiningKind::kind: \
539 result = rewriter.create<fop>(loc, resultType, result, next); \
550 #undef INT_OR_FLOAT_CASE
563 matchAndRewrite(vector::SplatOp op, OpAdaptor adaptor,
565 Type dstType = getTypeConverter()->convertType(op.getType());
568 if (isa<spirv::ScalarType>(dstType)) {
569 rewriter.
replaceOp(op, adaptor.getInput());
571 auto dstVecType = cast<VectorType>(dstType);
581 struct VectorShuffleOpConvert final
586 matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
588 VectorType oldResultType = shuffleOp.getResultVectorType();
589 Type newResultType = getTypeConverter()->convertType(oldResultType);
592 "unsupported result vector type");
594 auto mask = llvm::to_vector_of<int32_t>(shuffleOp.getMask());
596 VectorType oldV1Type = shuffleOp.getV1VectorType();
597 VectorType oldV2Type = shuffleOp.getV2VectorType();
601 if (oldV1Type.getNumElements() > 1 && oldV2Type.getNumElements() > 1 &&
602 oldResultType.getNumElements() > 1) {
604 shuffleOp, newResultType, adaptor.getV1(), adaptor.getV2(),
612 auto getElementAtIdx = [&rewriter, loc = shuffleOp.getLoc()](
614 if (
auto vecTy = dyn_cast<VectorType>(scalarOrVec.getType()))
615 return rewriter.
create<spirv::CompositeExtractOp>(loc, scalarOrVec,
618 assert(idx == 0 &&
"Invalid scalar element index");
622 int32_t numV1Elems = oldV1Type.getNumElements();
624 for (
auto [shuffleIdx, newOperand] : llvm::zip_equal(mask, newOperands)) {
625 Value vec = adaptor.getV1();
626 int32_t elementIdx = shuffleIdx;
627 if (elementIdx >= numV1Elems) {
628 vec = adaptor.getV2();
629 elementIdx -= numV1Elems;
632 newOperand = getElementAtIdx(vec, elementIdx);
636 if (newOperands.size() == 1) {
637 rewriter.
replaceOp(shuffleOp, newOperands.front());
642 shuffleOp, newResultType, newOperands);
647 struct VectorInterleaveOpConvert final
652 matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
655 VectorType oldResultType = interleaveOp.getResultVectorType();
656 Type newResultType = getTypeConverter()->convertType(oldResultType);
659 "unsupported result vector type");
662 VectorType sourceType = interleaveOp.getSourceVectorType();
663 int n = sourceType.getNumElements();
669 Value newOperands[] = {adaptor.getLhs(), adaptor.getRhs()};
671 interleaveOp, newResultType, newOperands);
675 auto seq = llvm::seq<int64_t>(2 * n);
676 auto indices = llvm::map_to_vector(
677 seq, [n](
int i) {
return (i % 2 ? n : 0) + i / 2; });
681 interleaveOp, newResultType, adaptor.getLhs(), adaptor.getRhs(),
688 struct VectorDeinterleaveOpConvert final
693 matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
697 VectorType oldResultType = deinterleaveOp.getResultVectorType();
698 Type newResultType = getTypeConverter()->convertType(oldResultType);
701 "unsupported result vector type");
703 Location loc = deinterleaveOp->getLoc();
706 Value sourceVector = adaptor.getSource();
707 VectorType sourceType = deinterleaveOp.getSourceVectorType();
708 int n = sourceType.getNumElements();
714 auto elem0 = rewriter.
create<spirv::CompositeExtractOp>(
717 auto elem1 = rewriter.
create<spirv::CompositeExtractOp>(
720 rewriter.
replaceOp(deinterleaveOp, {elem0, elem1});
725 auto seqEven = llvm::seq<int64_t>(n / 2);
727 llvm::map_to_vector(seqEven, [](
int i) {
return i * 2; });
730 auto seqOdd = llvm::seq<int64_t>(n / 2);
732 llvm::map_to_vector(seqOdd, [](
int i) {
return i * 2 + 1; });
735 auto shuffleEven = rewriter.
create<spirv::VectorShuffleOp>(
736 loc, newResultType, sourceVector, sourceVector,
739 auto shuffleOdd = rewriter.
create<spirv::VectorShuffleOp>(
740 loc, newResultType, sourceVector, sourceVector,
743 rewriter.
replaceOp(deinterleaveOp, {shuffleEven, shuffleOdd});
748 struct VectorLoadOpConverter final
753 matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
755 auto memrefType = loadOp.getMemRefType();
757 dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
760 loadOp,
"expected spirv.storage_class memory space");
762 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
763 auto loc = loadOp.getLoc();
766 adaptor.getIndices(), loc, rewriter);
769 loadOp,
"failed to get memref element pointer");
771 spirv::StorageClass storageClass = attr.getValue();
772 auto vectorType = loadOp.getVectorType();
775 auto spirvVectorType = typeConverter.
convertType(vectorType);
781 Value castedAccessChain = (vectorType.getNumElements() == 1)
783 : rewriter.
create<spirv::BitcastOp>(
784 loc, vectorPtrType, accessChain);
793 struct VectorStoreOpConverter final
798 matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
800 auto memrefType = storeOp.getMemRefType();
802 dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
805 storeOp,
"expected spirv.storage_class memory space");
807 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
808 auto loc = storeOp.getLoc();
811 adaptor.getIndices(), loc, rewriter);
814 storeOp,
"failed to get memref element pointer");
816 spirv::StorageClass storageClass = attr.getValue();
817 auto vectorType = storeOp.getVectorType();
823 Value castedAccessChain = (vectorType.getNumElements() == 1)
825 : rewriter.
create<spirv::BitcastOp>(
826 loc, vectorPtrType, accessChain);
829 adaptor.getValueToStore());
835 struct VectorReductionToIntDotProd final
839 LogicalResult matchAndRewrite(vector::ReductionOp op,
841 if (op.getKind() != vector::CombiningKind::ADD)
844 auto resultType = dyn_cast<IntegerType>(op.getType());
849 if (!llvm::is_contained({32, 64}, resultBitwidth))
852 VectorType inVecTy = op.getSourceVectorType();
853 if (!llvm::is_contained({4, 3}, inVecTy.getNumElements()) ||
854 inVecTy.getShape().size() != 1 || inVecTy.isScalable())
857 auto mul = op.getVector().getDefiningOp<arith::MulIOp>();
860 op,
"reduction operand is not 'arith.muli'");
862 if (succeeded(handleCase<arith::ExtSIOp, arith::ExtSIOp, spirv::SDotOp,
863 spirv::SDotAccSatOp,
false>(op, mul, rewriter)))
866 if (succeeded(handleCase<arith::ExtUIOp, arith::ExtUIOp, spirv::UDotOp,
867 spirv::UDotAccSatOp,
false>(op, mul, rewriter)))
870 if (succeeded(handleCase<arith::ExtSIOp, arith::ExtUIOp, spirv::SUDotOp,
871 spirv::SUDotAccSatOp,
false>(op, mul, rewriter)))
874 if (succeeded(handleCase<arith::ExtUIOp, arith::ExtSIOp, spirv::SUDotOp,
875 spirv::SUDotAccSatOp,
true>(op, mul, rewriter)))
882 template <
typename LhsExtensionOp,
typename RhsExtensionOp,
typename DotOp,
883 typename DotAccOp,
bool SwapOperands>
884 static LogicalResult handleCase(vector::ReductionOp op, arith::MulIOp mul,
886 auto lhs = mul.getLhs().getDefiningOp<LhsExtensionOp>();
889 Value lhsIn = lhs.getIn();
890 auto lhsInType = cast<VectorType>(lhsIn.
getType());
891 if (!lhsInType.getElementType().isInteger(8))
894 auto rhs = mul.getRhs().getDefiningOp<RhsExtensionOp>();
897 Value rhsIn = rhs.getIn();
898 auto rhsInType = cast<VectorType>(rhsIn.
getType());
899 if (!rhsInType.getElementType().isInteger(8))
902 if (op.getSourceVectorType().getNumElements() == 3) {
903 IntegerType i8Type = rewriter.
getI8Type();
907 lhsIn = rewriter.
create<spirv::CompositeConstructOp>(
909 rhsIn = rewriter.
create<spirv::CompositeConstructOp>(
916 std::swap(lhsIn, rhsIn);
918 if (
Value acc = op.getAcc()) {
930 struct VectorReductionToFPDotProd final
935 matchAndRewrite(vector::ReductionOp op, OpAdaptor adaptor,
937 if (op.getKind() != vector::CombiningKind::ADD)
940 auto resultType = getTypeConverter()->convertType<FloatType>(op.getType());
944 Value vec = adaptor.getVector();
945 Value acc = adaptor.getAcc();
947 auto vectorType = dyn_cast<VectorType>(vec.
getType());
949 assert(isa<FloatType>(vec.
getType()) &&
950 "Expected the vector to be scalarized");
971 rewriter.
getFloatAttr(vectorType.getElementType(), 1.0);
973 rhs = rewriter.
create<spirv::ConstantOp>(loc, vectorType, oneAttr);
978 Value res = rewriter.
create<spirv::DotOp>(loc, resultType, lhs, rhs);
980 res = rewriter.
create<spirv::FAddOp>(loc, acc, res);
991 matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
993 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
999 int64_t numElements = stepOp.getType().getNumElements();
1005 if (numElements == 1) {
1012 source.reserve(numElements);
1013 for (int64_t i = 0; i < numElements; ++i) {
1015 Value constOp = rewriter.
create<spirv::ConstantOp>(loc, intType, intAttr);
1016 source.push_back(constOp);
1025 #define CL_INT_MAX_MIN_OPS \
1026 spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
1028 #define GL_INT_MAX_MIN_OPS \
1029 spirv::GLUMaxOp, spirv::GLUMinOp, spirv::GLSMaxOp, spirv::GLSMinOp
1031 #define CL_FLOAT_MAX_MIN_OPS spirv::CLFMaxOp, spirv::CLFMinOp
1032 #define GL_FLOAT_MAX_MIN_OPS spirv::GLFMaxOp, spirv::GLFMinOp
1037 VectorBitcastConvert, VectorBroadcastConvert,
1038 VectorExtractElementOpConvert, VectorExtractOpConvert,
1039 VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
1040 VectorFmaOpConvert<spirv::CLFmaOp>, VectorFromElementsOpConvert,
1041 VectorInsertElementOpConvert, VectorInsertOpConvert,
1042 VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
1043 VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
1044 VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
1045 VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
1046 VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
1047 VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
1048 VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter,
1049 VectorStepOpConvert>(typeConverter,
patterns.getContext(),
1054 patterns.add<VectorReductionToFPDotProd>(typeConverter,
patterns.getContext(),
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static uint64_t getFirstIntValue(ArrayAttr attr)
Returns the integer value from the first valid input element, assuming Value inputs are defined by a ...
static int getNumBits(Type type)
Returns the number of bits for the given scalar/vector type.
#define INT_AND_FLOAT_CASE(kind, iop, fop)
#define INT_OR_FLOAT_CASE(kind, fop)
Attributes are known-constant values of operations.
IntegerAttr getIntegerAttr(Type type, int64_t value)
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class implements the operand iterators for the Operation class.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Type conversion from builtin types to SPIR-V types for shader interface.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static PointerType get(Type pointeeType, StorageClass storageClass)
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void populateVectorReductionToSPIRVDotProductPatterns(RewritePatternSet &patterns)
Appends patterns to convert vector reduction of the form:
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateVectorToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating Vector Ops to SPIR-V ops.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...