MLIR 22.0.0git
TensorToSPIRV.cpp
Go to the documentation of this file.
1//===- TensorToSPIRV.cpp - Tensor to SPIR-V Patterns ----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns to convert Tensor dialect to SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
18#include "mlir/IR/AffineMap.h"
19
20#define DEBUG_TYPE "tensor-to-spirv-pattern"
21
22using namespace mlir;
23
24//===----------------------------------------------------------------------===//
25// Operation conversion
26//===----------------------------------------------------------------------===//
27
28namespace {
29
30/// Converts tensor.extract into loading using access chains from SPIR-V local
31/// variables.
32class TensorExtractPattern final
33 : public OpConversionPattern<tensor::ExtractOp> {
34public:
35 TensorExtractPattern(const TypeConverter &typeConverter, MLIRContext *context,
36 int64_t threshold, PatternBenefit benefit = 1)
37 : OpConversionPattern(typeConverter, context, benefit),
38 byteCountThreshold(threshold) {}
39
40 LogicalResult
41 matchAndRewrite(tensor::ExtractOp extractOp, OpAdaptor adaptor,
42 ConversionPatternRewriter &rewriter) const override {
43 auto tensorType = cast<RankedTensorType>(extractOp.getTensor().getType());
44
45 if (!isa<spirv::ScalarType>(tensorType.getElementType()))
46 return rewriter.notifyMatchFailure(extractOp, "unsupported type");
47 if (!tensorType.hasStaticShape())
48 return rewriter.notifyMatchFailure(extractOp, "non-static tensor");
49
50 if (tensorType.getNumElements() * tensorType.getElementTypeBitWidth() >
51 byteCountThreshold * 8)
52 return rewriter.notifyMatchFailure(extractOp,
53 "exceeding byte count threshold");
54
55 Location loc = extractOp.getLoc();
56
57 int64_t rank = tensorType.getRank();
58 SmallVector<int64_t, 4> strides(rank, 1);
59 for (int i = rank - 2; i >= 0; --i) {
60 strides[i] = strides[i + 1] * tensorType.getDimSize(i + 1);
61 }
62
63 Type varType = spirv::PointerType::get(adaptor.getTensor().getType(),
64 spirv::StorageClass::Function);
65
66 spirv::VariableOp varOp;
67 if (adaptor.getTensor().getDefiningOp<spirv::ConstantOp>()) {
68 // We could use the initializer directly; but certain driver compilers
69 // have bugs dealing with that. So for now, use spirv.Store for
70 // initialization.
71 varOp = spirv::VariableOp::create(rewriter, loc, varType,
72 spirv::StorageClass::Function,
73 /*initializer=*/nullptr);
74 spirv::StoreOp::create(rewriter, loc, varOp, adaptor.getTensor());
75 } else {
76 // Need to store the value to the local variable. It's questionable
77 // whether we want to support such case though.
78 return failure();
79 }
80
81 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
82 auto indexType = typeConverter.getIndexType();
83
84 Value index = spirv::linearizeIndex(adaptor.getIndices(), strides,
85 /*offset=*/0, indexType, loc, rewriter);
86 auto acOp = spirv::AccessChainOp::create(rewriter, loc, varOp, index);
87
88 rewriter.replaceOpWithNewOp<spirv::LoadOp>(extractOp, acOp);
89
90 return success();
91 }
92
93private:
94 int64_t byteCountThreshold;
95};
96
97} // namespace
98
99//===----------------------------------------------------------------------===//
100// Pattern population
101//===----------------------------------------------------------------------===//
102
104 const SPIRVTypeConverter &typeConverter, int64_t byteCountThreshold,
106 patterns.add<TensorExtractPattern>(typeConverter, patterns.getContext(),
107 byteCountThreshold);
108}
return success()
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
static PointerType get(Type pointeeType, StorageClass storageClass)
Value linearizeIndex(ValueRange indices, ArrayRef< int64_t > strides, int64_t offset, Type integerType, Location loc, OpBuilder &builder)
Generates IR to perform index linearization with the given indices and their corresponding strides,...
Include the generated interface declarations.
void populateTensorToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, int64_t byteCountThreshold, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating tensor ops to SPIR-V ops.
const FrozenRewritePatternSet & patterns