14 #include "../SPIRVCommon/Pattern.h"
21 #include "llvm/Support/Debug.h"
23 #define DEBUG_TYPE "tensor-to-spirv-pattern"
35 class TensorExtractPattern final
41 byteCountThreshold(threshold) {}
44 matchAndRewrite(tensor::ExtractOp extractOp, OpAdaptor adaptor,
46 auto tensorType = cast<RankedTensorType>(extractOp.getTensor().getType());
48 if (!isa<spirv::ScalarType>(tensorType.getElementType()))
50 if (!tensorType.hasStaticShape())
53 if (tensorType.getNumElements() * tensorType.getElementTypeBitWidth() >
54 byteCountThreshold * 8)
56 "exceeding byte count threshold");
60 int64_t rank = tensorType.getRank();
62 for (
int i = rank - 2; i >= 0; --i) {
63 strides[i] = strides[i + 1] * tensorType.getDimSize(i + 1);
67 spirv::StorageClass::Function);
69 spirv::VariableOp varOp;
70 if (adaptor.getTensor().getDefiningOp<spirv::ConstantOp>()) {
74 varOp = rewriter.
create<spirv::VariableOp>(loc, varType,
75 spirv::StorageClass::Function,
77 rewriter.
create<spirv::StoreOp>(loc, varOp, adaptor.getTensor());
84 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
85 auto indexType = typeConverter.getIndexType();
88 0, indexType, loc, rewriter);
89 auto acOp = rewriter.
create<spirv::AccessChainOp>(loc, varOp, index);
97 int64_t byteCountThreshold;
109 patterns.
add<TensorExtractPattern>(typeConverter, patterns.
getContext(),
This class implements a pattern rewriter for use with ConversionPatterns.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
static PointerType get(Type pointeeType, StorageClass storageClass)
Value linearizeIndex(ValueRange indices, ArrayRef< int64_t > strides, int64_t offset, Type integerType, Location loc, OpBuilder &builder)
Generates IR to perform index linearization with the given indices and their corresponding strides,...
Include the generated interface declarations.
void populateTensorToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, int64_t byteCountThreshold, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating tensor ops to SPIR-V ops.