MLIR 22.0.0git
TosaNarrowI64ToI32.cpp
Go to the documentation of this file.
1//===- TosaNarrowI64ToI32.cpp ---------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass narrows TOSA operations with 64-bit integer tensor types to
10// 32-bit integer tensor types. This can be useful for backends that do not
11// support the EXT-INT64 extension of TOSA. The pass has two options:
12//
13// - aggressive-rewrite - If enabled, all TOSA operations are rewritten,
// regardless of whether the narrowing is safe. This option may lead to
15// data loss if not used carefully.
16// - convert-function-boundaries - If enabled, the pass will convert function
17// I/O types as well. Otherwise casts will be inserted at the I/O
18// boundaries.
19//
20//===----------------------------------------------------------------------===//
21
23
#include "mlir/Dialect/Tosa/Transforms/Passes.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Pass/Pass.h"
28
29namespace mlir {
30namespace tosa {
31#define GEN_PASS_DEF_TOSANARROWI64TOI32PASS
32#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
33} // namespace tosa
34} // namespace mlir
35
36using namespace mlir;
37using namespace mlir::tosa;
38
39namespace {
40
41LogicalResult convertGenericOp(Operation *op, ValueRange operands,
42 ConversionPatternRewriter &rewriter,
43 const TypeConverter *typeConverter) {
44 // Convert types of results
45 SmallVector<Type, 4> newResults;
46 if (failed(typeConverter->convertTypes(op->getResultTypes(), newResults)))
47 return failure();
48
49 // Create a new operation state
50 OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
51 newResults, {}, op->getSuccessors());
52
53 for (const NamedAttribute &namedAttribute : op->getAttrs()) {
54 const Attribute attribute = namedAttribute.getValue();
55
56 // Convert integer attribute type
57 if (const auto intAttr = dyn_cast<IntegerAttr>(attribute)) {
58 const std::optional<Attribute> convertedAttribute =
59 typeConverter->convertTypeAttribute(intAttr.getType(), attribute);
60 state.addAttribute(namedAttribute.getName(), convertedAttribute.value());
61 continue;
62 }
63
64 if (const auto typeAttr = dyn_cast<TypeAttr>(attribute)) {
65 Type type = typeAttr.getValue();
66 const std::optional<Attribute> convertedAttribute =
67 typeConverter->convertTypeAttribute(type, attribute);
68 if (!convertedAttribute)
69 return rewriter.notifyMatchFailure(op,
70 "Failed to convert type attribute.");
71 state.addAttribute(namedAttribute.getName(), convertedAttribute.value());
72 continue;
73 }
74
75 if (const auto denseElementsAttr = dyn_cast<DenseElementsAttr>(attribute)) {
76 const Type type = denseElementsAttr.getType();
77 const std::optional<Attribute> convertedAttribute =
78 typeConverter->convertTypeAttribute(type, denseElementsAttr);
79 if (!convertedAttribute)
80 return rewriter.notifyMatchFailure(
81 op, "Failed to convert dense elements attribute.");
82 state.addAttribute(namedAttribute.getName(), convertedAttribute.value());
83 continue;
84 }
85
86 state.addAttribute(namedAttribute.getName(), attribute);
87 }
88
89 for (Region &region : op->getRegions()) {
90 Region *newRegion = state.addRegion();
91 rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin());
92 if (failed(rewriter.convertRegionTypes(newRegion, *typeConverter)))
93 return failure();
94 }
95
96 Operation *newOp = rewriter.create(state);
97 rewriter.replaceOp(op, newOp->getResults());
98 return success();
99}
100
101// ===========================
102// Aggressive rewrite patterns
103// ===========================
104
105class ConvertGenericOp : public ConversionPattern {
106public:
107 ConvertGenericOp(TypeConverter &typeConverter, MLIRContext *context)
108 : ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, 0, context) {}
109
110 LogicalResult
111 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
112 ConversionPatternRewriter &rewriter) const final {
113 if (!isa<tosa::TosaOp>(op))
114 return rewriter.notifyMatchFailure(
115 op,
116 "Support for operations other than TOSA has not been implemented.");
117
118 return convertGenericOp(op, operands, rewriter, typeConverter);
119 }
120};
121
122// ===============================
123// Bounds checked rewrite patterns
124// ===============================
125
126class ConvertArgMaxOpWithBoundsChecking
127 : public OpConversionPattern<tosa::ArgMaxOp> {
128 using OpConversionPattern::OpConversionPattern;
129
130 LogicalResult
131 matchAndRewrite(tosa::ArgMaxOp op, OpAdaptor adaptor,
132 ConversionPatternRewriter &rewriter) const final {
133 // Output type can be narrowed based on the size of the axis dimension
134 const int32_t axis = op.getAxis();
135 const auto inputType = dyn_cast<ShapedType>(adaptor.getInput().getType());
136 if (!inputType || !inputType.isStaticDim(axis))
137 return rewriter.notifyMatchFailure(
138 op, "Requires a static axis dimension for bounds checking.");
139 const int64_t axisDim = inputType.getDimSize(axis);
140 if (axisDim >= std::numeric_limits<int32_t>::max())
141 return rewriter.notifyMatchFailure(
142 op, "Axis dimension is too large to narrow safely.");
143
144 const Type resultType = op.getOutput().getType();
145 const Type newResultType = typeConverter->convertType(resultType);
146 rewriter.replaceOpWithNewOp<tosa::ArgMaxOp>(op, newResultType,
147 adaptor.getInput(), axis);
148 return success();
149 }
150};
151
152class ConvertCastOpWithBoundsChecking
153 : public OpConversionPattern<tosa::CastOp> {
154 using OpConversionPattern::OpConversionPattern;
155
156 LogicalResult
157 matchAndRewrite(tosa::CastOp op, OpAdaptor adaptor,
158 ConversionPatternRewriter &rewriter) const final {
159 const auto inputType = dyn_cast<ShapedType>(adaptor.getInput().getType());
160 const auto resultType = dyn_cast<ShapedType>(op.getResult().getType());
161 if (!inputType || !resultType)
162 return failure();
163
164 const auto elementInputIntType =
165 dyn_cast<IntegerType>(inputType.getElementType());
166 const auto elementResultIntType =
167 dyn_cast<IntegerType>(resultType.getElementType());
168 if (elementInputIntType && elementResultIntType &&
169 elementInputIntType.getWidth() > elementResultIntType.getWidth())
170 return rewriter.notifyMatchFailure(
171 op, "Narrowing cast may lead to data loss.");
172
173 rewriter.replaceOpWithNewOp<tosa::CastOp>(
174 op, typeConverter->convertType(resultType), adaptor.getInput());
175 return success();
176 }
177};
178
179template <typename OpTy>
180class ConvertTypedOp : public OpConversionPattern<OpTy> {
181 using OpConversionPattern<OpTy>::OpConversionPattern;
182
183 LogicalResult
184 matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
185 ConversionPatternRewriter &rewriter) const final {
186 return convertGenericOp(op, adaptor.getOperands(), rewriter,
187 this->getTypeConverter());
188 }
189};
190
191struct TosaNarrowI64ToI32
192 : public tosa::impl::TosaNarrowI64ToI32PassBase<TosaNarrowI64ToI32> {
193public:
194 explicit TosaNarrowI64ToI32() = default;
195 explicit TosaNarrowI64ToI32(const TosaNarrowI64ToI32PassOptions &options)
196 : TosaNarrowI64ToI32() {
197 this->aggressiveRewrite = options.aggressiveRewrite;
198 this->convertFunctionBoundaries = options.convertFunctionBoundaries;
199 }
200
201 void runOnOperation() override {
202 MLIRContext *context = &getContext();
203
204 TypeConverter typeConverter;
205 typeConverter.addConversion([](Type type) -> Type { return type; });
206 typeConverter.addConversion([](IntegerType type) -> Type {
207 if (!type.isInteger(64))
208 return type;
209 return IntegerType::get(type.getContext(), 32);
210 });
211 typeConverter.addConversion(
212 [&typeConverter](RankedTensorType type) -> Type {
213 const Type elementType = type.getElementType();
214 if (!elementType.isInteger(64))
215 return type;
216 return RankedTensorType::get(type.getShape(),
217 typeConverter.convertType(elementType));
218 });
219
220 const auto materializeCast = [](OpBuilder &builder, Type resultType,
221 ValueRange inputs, Location loc) -> Value {
222 if (inputs.size() != 1)
223 return Value();
224 return tosa::CastOp::create(builder, loc, resultType, inputs.front());
225 };
226 typeConverter.addSourceMaterialization(materializeCast);
227 typeConverter.addTargetMaterialization(materializeCast);
228
229 typeConverter.addTypeAttributeConversion(
230 [](IntegerType type, IntegerAttr attribute) -> Attribute {
231 const APInt value = attribute.getValue().truncSSat(32);
232 return IntegerAttr::get(IntegerType::get(type.getContext(), 32),
233 value);
234 });
235 typeConverter.addTypeAttributeConversion(
236 [&typeConverter](ShapedType type,
237 DenseIntElementsAttr attr) -> Attribute {
238 const ShapedType newType =
239 cast<ShapedType>(typeConverter.convertType(type));
240 const auto oldElementType = cast<IntegerType>(type.getElementType());
241 const auto newElementType =
242 cast<IntegerType>(newType.getElementType());
243 if (oldElementType.getWidth() == newElementType.getWidth())
244 return attr;
245
246 DenseElementsAttr mapped =
247 attr.mapValues(newElementType, [&](const APInt &v) {
248 return v.truncSSat(newElementType.getWidth());
249 });
250 return mapped;
251 });
252
253 ConversionTarget target(*context);
254 target.addDynamicallyLegalDialect<tosa::TosaDialect>(
255 [&typeConverter](Operation *op) {
256 return typeConverter.isLegal(op->getResultTypes()) &&
257 typeConverter.isLegal(op->getOperandTypes());
258 });
259 if (convertFunctionBoundaries) {
260 target.addDynamicallyLegalOp<func::FuncOp>(
261 [&typeConverter](func::FuncOp op) {
262 return typeConverter.isSignatureLegal(op.getFunctionType()) &&
263 typeConverter.isLegal(&op.getBody());
264 });
265 target.addDynamicallyLegalOp<func::ReturnOp>([](func::ReturnOp op) {
266 const FunctionType funcType =
267 op->getParentOfType<func::FuncOp>().getFunctionType();
268 return llvm::equal(op.getOperandTypes(), funcType.getResults());
269 });
270 } else {
271 target.addDynamicallyLegalOp<func::FuncOp>(
272 [](func::FuncOp op) { return true; });
273 target.addDynamicallyLegalOp<func::ReturnOp>(
274 [](func::ReturnOp op) { return true; });
275 }
276
277 RewritePatternSet patterns(context);
278 if (convertFunctionBoundaries) {
279 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
280 patterns, typeConverter);
282 }
283 if (aggressiveRewrite) {
284 patterns.add<ConvertGenericOp>(typeConverter, context);
285 } else {
286 // Tensor
287 patterns.add<ConvertArgMaxOpWithBoundsChecking>(typeConverter, context);
288 // Data layout
289 patterns.add<ConvertTypedOp<tosa::ConcatOp>>(typeConverter, context);
290 patterns.add<ConvertTypedOp<tosa::PadOp>>(typeConverter, context);
291 patterns.add<ConvertTypedOp<tosa::ReshapeOp>>(typeConverter, context);
292 patterns.add<ConvertTypedOp<tosa::ReverseOp>>(typeConverter, context);
293 patterns.add<ConvertTypedOp<tosa::SliceOp>>(typeConverter, context);
294 patterns.add<ConvertTypedOp<tosa::TileOp>>(typeConverter, context);
295 patterns.add<ConvertTypedOp<tosa::TransposeOp>>(typeConverter, context);
296 patterns.add<ConvertTypedOp<tosa::IdentityOp>>(typeConverter, context);
297 // Type conversion
298 patterns.add<ConvertCastOpWithBoundsChecking>(typeConverter, context);
299 // Controlflow
300 patterns.add<ConvertTypedOp<tosa::IfOp>>(typeConverter, context);
301 patterns.add<ConvertTypedOp<tosa::WhileOp>>(typeConverter, context);
302 }
303
304 if (failed(
305 applyFullConversion(getOperation(), target, std::move(patterns))))
306 signalPassFailure();
307 }
308};
309
310} // namespace
return success()
b getContext())
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
Definition Attributes.h:25
DenseElementsAttr mapValues(Type newElementType, function_ref< APInt(const APInt &)> mapping) const
Generates a new DenseElementsAttr by mapping each value attribute, and constructing the DenseElements...
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:512
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:238
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
operand_type_range getOperandTypes()
Definition Operation.h:397
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:677
result_type_range getResultTypes()
Definition Operation.h:428
SuccessorRange getSuccessors()
Definition Operation.h:703
result_range getResults()
Definition Operation.h:415
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
iterator begin()
Definition Region.h:55
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:56
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
Type getType() const
Return the type of this value.
Definition Value.h:105
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
void populateReturnOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter, PatternBenefit benefit=1)
Add a pattern to the given pattern list to rewrite return ops to use operands that have been legalize...
const FrozenRewritePatternSet & patterns
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
Region * addRegion()
Create a region that should be attached to the operation.