20#define DEBUG_TYPE "tensor-to-spirv-pattern"
32class TensorExtractPattern final
33 :
public OpConversionPattern<tensor::ExtractOp> {
37 : OpConversionPattern(typeConverter, context, benefit),
38 byteCountThreshold(threshold) {}
41 matchAndRewrite(tensor::ExtractOp extractOp, OpAdaptor adaptor,
42 ConversionPatternRewriter &rewriter)
const override {
43 auto tensorType = cast<RankedTensorType>(extractOp.getTensor().getType());
45 if (!isa<spirv::ScalarType>(tensorType.getElementType()))
46 return rewriter.notifyMatchFailure(extractOp,
"unsupported type");
47 if (!tensorType.hasStaticShape())
48 return rewriter.notifyMatchFailure(extractOp,
"non-static tensor");
50 if (tensorType.getNumElements() * tensorType.getElementTypeBitWidth() >
51 byteCountThreshold * 8)
52 return rewriter.notifyMatchFailure(extractOp,
53 "exceeding byte count threshold");
57 int64_t rank = tensorType.getRank();
59 for (
int i = rank - 2; i >= 0; --i) {
60 strides[i] = strides[i + 1] * tensorType.getDimSize(i + 1);
64 spirv::StorageClass::Function);
66 spirv::VariableOp varOp;
67 if (adaptor.getTensor().getDefiningOp<spirv::ConstantOp>()) {
71 varOp = spirv::VariableOp::create(rewriter, loc, varType,
72 spirv::StorageClass::Function,
74 spirv::StoreOp::create(rewriter, loc, varOp, adaptor.getTensor());
81 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
82 auto indexType = typeConverter.getIndexType();
85 0, indexType, loc, rewriter);
86 auto acOp = spirv::AccessChainOp::create(rewriter, loc, varOp,
index);
88 rewriter.replaceOpWithNewOp<spirv::LoadOp>(extractOp, acOp);
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
static PointerType get(Type pointeeType, StorageClass storageClass)
Value linearizeIndex(ValueRange indices, ArrayRef< int64_t > strides, int64_t offset, Type integerType, Location loc, OpBuilder &builder)
Generates IR to perform index linearization with the given indices and their corresponding strides,...
Include the generated interface declarations.
void populateTensorToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, int64_t byteCountThreshold, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating tensor ops to SPIR-V ops.
const FrozenRewritePatternSet & patterns