30 #include "llvm/ADT/ArrayRef.h"
31 #include "llvm/ADT/STLExtras.h"
32 #include "llvm/ADT/SmallVector.h"
33 #include "llvm/ADT/SmallVectorExtras.h"
34 #include "llvm/Support/FormatVariadic.h"
45 return values[0].getDefiningOp<arith::ConstantIndexOp>().value();
48 return cast<IntegerAttr>(attr[0]).getInt();
51 return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
54 auto attr = foldResults[0].dyn_cast<
Attribute>();
66 if (
auto vectorType = dyn_cast<VectorType>(type))
67 return vectorType.getNumElements() * vectorType.getElementTypeBitWidth();
77 matchAndRewrite(vector::ShapeCastOp shapeCastOp, OpAdaptor adaptor,
79 Type dstType = getTypeConverter()->convertType(shapeCastOp.getType());
85 if (dstType == adaptor.getSource().getType() ||
86 shapeCastOp.getResultVectorType().getNumElements() == 1) {
87 rewriter.
replaceOp(shapeCastOp, adaptor.getSource());
96 struct VectorBitcastConvert final
101 matchAndRewrite(vector::BitCastOp bitcastOp, OpAdaptor adaptor,
103 Type dstType = getTypeConverter()->convertType(bitcastOp.getType());
107 if (dstType == adaptor.getSource().getType()) {
108 rewriter.
replaceOp(bitcastOp, adaptor.getSource());
115 Type srcType = adaptor.getSource().getType();
119 llvm::formatv(
"different source ({0}) and target ({1}) bitwidth",
124 adaptor.getSource());
129 struct VectorBroadcastConvert final
134 matchAndRewrite(vector::BroadcastOp castOp, OpAdaptor adaptor,
137 getTypeConverter()->convertType(castOp.getResultVectorType());
141 if (isa<spirv::ScalarType>(resultType)) {
142 rewriter.
replaceOp(castOp, adaptor.getSource());
147 adaptor.getSource());
149 castOp, castOp.getResultVectorType(), source);
154 struct VectorExtractOpConvert final
159 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
161 if (extractOp.hasDynamicPosition())
164 Type dstType = getTypeConverter()->convertType(extractOp.getType());
168 if (isa<spirv::ScalarType>(adaptor.getVector().getType())) {
169 rewriter.
replaceOp(extractOp, adaptor.getVector());
175 extractOp, adaptor.getVector(), id);
180 struct VectorExtractStridedSliceOpConvert final
185 matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
187 Type dstType = getTypeConverter()->convertType(extractOp.getType());
197 Value srcVector = adaptor.getOperands().front();
200 if (isa<spirv::ScalarType>(dstType)) {
207 std::iota(indices.begin(), indices.end(), offset);
210 extractOp, dstType, srcVector, srcVector,
217 template <
class SPIRVFMAOp>
222 matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
224 Type dstType = getTypeConverter()->convertType(fmaOp.getType());
228 adaptor.getRhs(), adaptor.getAcc());
233 struct VectorInsertOpConvert final
238 matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
240 if (isa<VectorType>(insertOp.getSourceType()))
242 if (!getTypeConverter()->convertType(insertOp.getDestVectorType()))
244 "unsupported dest vector type");
247 if (insertOp.getSourceType().isIntOrFloat() &&
248 insertOp.getDestVectorType().getNumElements() == 1) {
249 rewriter.
replaceOp(insertOp, adaptor.getSource());
255 insertOp, adaptor.getSource(), adaptor.getDest(), id);
260 struct VectorExtractElementOpConvert final
265 matchAndRewrite(vector::ExtractElementOp extractOp, OpAdaptor adaptor,
267 Type resultType = getTypeConverter()->convertType(extractOp.getType());
271 if (isa<spirv::ScalarType>(adaptor.getVector().getType())) {
272 rewriter.
replaceOp(extractOp, adaptor.getVector());
279 extractOp, resultType, adaptor.getVector(),
283 extractOp, resultType, adaptor.getVector(), adaptor.getPosition());
288 struct VectorInsertElementOpConvert final
293 matchAndRewrite(vector::InsertElementOp insertOp, OpAdaptor adaptor,
295 Type vectorType = getTypeConverter()->convertType(insertOp.getType());
299 if (isa<spirv::ScalarType>(vectorType)) {
300 rewriter.
replaceOp(insertOp, adaptor.getSource());
307 insertOp, adaptor.getSource(), adaptor.getDest(),
308 cstPos.getSExtValue());
311 insertOp, vectorType, insertOp.getDest(), adaptor.getSource(),
312 adaptor.getPosition());
317 struct VectorInsertStridedSliceOpConvert final
322 matchAndRewrite(vector::InsertStridedSliceOp insertOp, OpAdaptor adaptor,
324 Value srcVector = adaptor.getOperands().front();
325 Value dstVector = adaptor.getOperands().back();
332 if (isa<spirv::ScalarType>(srcVector.
getType())) {
333 assert(!isa<spirv::ScalarType>(dstVector.
getType()));
335 insertOp, dstVector.
getType(), srcVector, dstVector,
340 uint64_t totalSize = cast<VectorType>(dstVector.getType()).getNumElements();
341 uint64_t insertSize =
342 cast<VectorType>(srcVector.
getType()).getNumElements();
345 std::iota(indices.begin(), indices.end(), 0);
346 std::iota(indices.begin() + offset, indices.begin() + offset + insertSize,
350 insertOp, dstVector.getType(), dstVector, srcVector,
358 vector::ReductionOp reduceOp, vector::ReductionOp::Adaptor adaptor,
360 int numElements =
static_cast<int>(srcVectorType.getDimSize(0));
362 values.reserve(numElements + (adaptor.getAcc() ? 1 : 0));
365 for (
int i = 0; i < numElements; ++i) {
366 values.push_back(rewriter.
create<spirv::CompositeExtractOp>(
367 loc, srcVectorType.getElementType(), adaptor.getVector(),
370 if (
Value acc = adaptor.getAcc())
371 values.push_back(acc);
376 struct ReductionRewriteInfo {
382 vector::ReductionOp op, vector::ReductionOp::Adaptor adaptor,
388 auto srcVectorType = dyn_cast<VectorType>(adaptor.getVector().getType());
389 if (!srcVectorType || srcVectorType.getRank() != 1)
393 extractAllElements(op, adaptor, srcVectorType, rewriter);
395 return ReductionRewriteInfo{resultType, std::move(extractedElements)};
398 template <
typename SPIRVUMaxOp,
typename SPIRVUMinOp,
typename SPIRVSMaxOp,
399 typename SPIRVSMinOp>
404 matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
407 getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
408 if (
failed(reductionInfo))
411 auto [resultType, extractedElements] = *reductionInfo;
413 Value result = extractedElements.front();
414 for (
Value next : llvm::drop_begin(extractedElements)) {
415 switch (reduceOp.getKind()) {
417 #define INT_AND_FLOAT_CASE(kind, iop, fop) \
418 case vector::CombiningKind::kind: \
419 if (llvm::isa<IntegerType>(resultType)) { \
420 result = rewriter.create<spirv::iop>(loc, resultType, result, next); \
422 assert(llvm::isa<FloatType>(resultType)); \
423 result = rewriter.create<spirv::fop>(loc, resultType, result, next); \
427 #define INT_OR_FLOAT_CASE(kind, fop) \
428 case vector::CombiningKind::kind: \
429 result = rewriter.create<fop>(loc, resultType, result, next); \
439 case vector::CombiningKind::AND:
440 case vector::CombiningKind::OR:
441 case vector::CombiningKind::XOR:
446 #undef INT_AND_FLOAT_CASE
447 #undef INT_OR_FLOAT_CASE
455 template <
typename SPIRVFMaxOp,
typename SPIRVFMinOp>
456 struct VectorReductionFloatMinMax final
461 matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
464 getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
465 if (
failed(reductionInfo))
468 auto [resultType, extractedElements] = *reductionInfo;
470 Value result = extractedElements.front();
471 for (
Value next : llvm::drop_begin(extractedElements)) {
472 switch (reduceOp.getKind()) {
474 #define INT_OR_FLOAT_CASE(kind, fop) \
475 case vector::CombiningKind::kind: \
476 result = rewriter.create<fop>(loc, resultType, result, next); \
487 #undef INT_OR_FLOAT_CASE
500 matchAndRewrite(vector::SplatOp op, OpAdaptor adaptor,
502 Type dstType = getTypeConverter()->convertType(op.getType());
505 if (isa<spirv::ScalarType>(dstType)) {
506 rewriter.
replaceOp(op, adaptor.getInput());
508 auto dstVecType = cast<VectorType>(dstType);
518 struct VectorShuffleOpConvert final
523 matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
525 auto oldResultType = shuffleOp.getResultVectorType();
526 Type newResultType = getTypeConverter()->convertType(oldResultType);
529 "unsupported result vector type");
532 shuffleOp.getMask(), [](
Attribute attr) -> int32_t {
533 return cast<IntegerAttr>(attr).getValue().getZExtValue();
536 auto oldV1Type = shuffleOp.getV1VectorType();
537 auto oldV2Type = shuffleOp.getV2VectorType();
540 if (oldV1Type.getNumElements() > 1 && oldV2Type.getNumElements() > 1) {
542 shuffleOp, newResultType, adaptor.getV1(), adaptor.getV2(),
550 auto getElementAtIdx = [&rewriter, loc = shuffleOp.getLoc()](
552 if (
auto vecTy = dyn_cast<VectorType>(scalarOrVec.getType()))
553 return rewriter.
create<spirv::CompositeExtractOp>(loc, scalarOrVec,
556 assert(idx == 0 &&
"Invalid scalar element index");
560 int32_t numV1Elems = oldV1Type.getNumElements();
562 for (
auto [shuffleIdx, newOperand] : llvm::zip_equal(mask, newOperands)) {
563 Value vec = adaptor.getV1();
564 int32_t elementIdx = shuffleIdx;
565 if (elementIdx >= numV1Elems) {
566 vec = adaptor.getV2();
567 elementIdx -= numV1Elems;
570 newOperand = getElementAtIdx(vec, elementIdx);
574 shuffleOp, newResultType, newOperands);
580 struct VectorLoadOpConverter final
585 matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
587 auto memrefType = loadOp.getMemRefType();
589 dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
592 loadOp,
"expected spirv.storage_class memory space");
594 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
595 auto loc = loadOp.getLoc();
598 adaptor.getIndices(), loc, rewriter);
601 loadOp,
"failed to get memref element pointer");
603 spirv::StorageClass storageClass = attr.getValue();
604 auto vectorType = loadOp.getVectorType();
606 Value castedAccessChain =
607 rewriter.
create<spirv::BitcastOp>(loc, vectorPtrType, accessChain);
615 struct VectorStoreOpConverter final
620 matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
622 auto memrefType = storeOp.getMemRefType();
624 dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
627 storeOp,
"expected spirv.storage_class memory space");
629 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
630 auto loc = storeOp.getLoc();
633 adaptor.getIndices(), loc, rewriter);
636 storeOp,
"failed to get memref element pointer");
638 spirv::StorageClass storageClass = attr.getValue();
639 auto vectorType = storeOp.getVectorType();
641 Value castedAccessChain =
642 rewriter.
create<spirv::BitcastOp>(loc, vectorPtrType, accessChain);
644 adaptor.getValueToStore());
650 struct VectorReductionToIntDotProd final
656 if (op.getKind() != vector::CombiningKind::ADD)
659 auto resultType = dyn_cast<IntegerType>(op.getType());
664 if (!llvm::is_contained({32, 64}, resultBitwidth))
667 VectorType inVecTy = op.getSourceVectorType();
668 if (!llvm::is_contained({4, 3}, inVecTy.getNumElements()) ||
669 inVecTy.getShape().size() != 1 || inVecTy.isScalable())
672 auto mul = op.getVector().getDefiningOp<arith::MulIOp>();
675 op,
"reduction operand is not 'arith.muli'");
677 if (
succeeded(handleCase<arith::ExtSIOp, arith::ExtSIOp, spirv::SDotOp,
678 spirv::SDotAccSatOp,
false>(op, mul, rewriter)))
681 if (
succeeded(handleCase<arith::ExtUIOp, arith::ExtUIOp, spirv::UDotOp,
682 spirv::UDotAccSatOp,
false>(op, mul, rewriter)))
685 if (
succeeded(handleCase<arith::ExtSIOp, arith::ExtUIOp, spirv::SUDotOp,
686 spirv::SUDotAccSatOp,
false>(op, mul, rewriter)))
689 if (
succeeded(handleCase<arith::ExtUIOp, arith::ExtSIOp, spirv::SUDotOp,
690 spirv::SUDotAccSatOp,
true>(op, mul, rewriter)))
697 template <
typename LhsExtensionOp,
typename RhsExtensionOp,
typename DotOp,
698 typename DotAccOp,
bool SwapOperands>
699 static LogicalResult handleCase(vector::ReductionOp op, arith::MulIOp mul,
701 auto lhs = mul.getLhs().getDefiningOp<LhsExtensionOp>();
704 Value lhsIn = lhs.getIn();
705 auto lhsInType = cast<VectorType>(lhsIn.
getType());
706 if (!lhsInType.getElementType().isInteger(8))
709 auto rhs = mul.getRhs().getDefiningOp<RhsExtensionOp>();
712 Value rhsIn = rhs.getIn();
713 auto rhsInType = cast<VectorType>(rhsIn.
getType());
714 if (!rhsInType.getElementType().isInteger(8))
717 if (op.getSourceVectorType().getNumElements() == 3) {
718 IntegerType i8Type = rewriter.
getI8Type();
722 lhsIn = rewriter.
create<spirv::CompositeConstructOp>(
724 rhsIn = rewriter.
create<spirv::CompositeConstructOp>(
731 std::swap(lhsIn, rhsIn);
733 if (
Value acc = op.getAcc()) {
745 struct VectorReductionToFPDotProd final
750 matchAndRewrite(vector::ReductionOp op, OpAdaptor adaptor,
752 if (op.getKind() != vector::CombiningKind::ADD)
755 auto resultType = getTypeConverter()->convertType<
FloatType>(op.getType());
759 Value vec = adaptor.getVector();
760 Value acc = adaptor.getAcc();
762 auto vectorType = dyn_cast<VectorType>(vec.
getType());
764 assert(isa<FloatType>(vec.
getType()) &&
765 "Expected the vector to be scalarized");
786 rewriter.
getFloatAttr(vectorType.getElementType(), 1.0);
788 rhs = rewriter.
create<spirv::ConstantOp>(loc, vectorType, oneAttr);
793 Value res = rewriter.
create<spirv::DotOp>(loc, resultType, lhs, rhs);
795 res = rewriter.
create<spirv::FAddOp>(loc, acc, res);
803 #define CL_INT_MAX_MIN_OPS \
804 spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
806 #define GL_INT_MAX_MIN_OPS \
807 spirv::GLUMaxOp, spirv::GLUMinOp, spirv::GLSMaxOp, spirv::GLSMinOp
809 #define CL_FLOAT_MAX_MIN_OPS spirv::CLFMaxOp, spirv::CLFMinOp
810 #define GL_FLOAT_MAX_MIN_OPS spirv::GLFMaxOp, spirv::GLFMinOp
815 VectorBitcastConvert, VectorBroadcastConvert,
816 VectorExtractElementOpConvert, VectorExtractOpConvert,
817 VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
818 VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
819 VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
820 VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
821 VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
822 VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
823 VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
824 VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
829 patterns.
add<VectorReductionToFPDotProd>(typeConverter, patterns.
getContext(),
835 patterns.
add<VectorReductionToIntDotProd>(patterns.
getContext());
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static uint64_t getFirstIntValue(ValueRange values)
Returns the integer value from the first valid input element, assuming Value inputs are defined by a ...
static int getNumBits(Type type)
Returns the number of bits for the given scalar/vector type.
#define INT_AND_FLOAT_CASE(kind, iop, fop)
#define INT_OR_FLOAT_CASE(kind, fop)
Attributes are known-constant values of operations.
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
FloatAttr getFloatAttr(Type type, double value)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class provides support for representing a failure result, or a valid value of type T.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Location getLoc()
The source location the operation was defined or derived from.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Type conversion from builtin types to SPIR-V types for shader interface.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static PointerType get(Type pointeeType, StorageClass storageClass)
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
void populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating Vector Ops to SPIR-V ops.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
void populateVectorReductionToSPIRVDotProductPatterns(RewritePatternSet &patterns)
Appends patterns to convert vector reduction of the form:
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...