MLIR 22.0.0git
TosaNarrowI64ToI32.cpp
Go to the documentation of this file.
1//===- TosaNarrowI64ToI32.cpp ---------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass narrows TOSA operations with 64-bit integer tensor types to
10// 32-bit integer tensor types. This can be useful for backends that do not
11// support the EXT-INT64 extension of TOSA. The pass has two options:
12//
13// - aggressive-rewrite - If enabled, all TOSA operations are rewritten,
14// regardless of whether the narrowing is safe. This option may lead to
15// data loss if not used carefully.
16// - convert-function-boundaries - If enabled, the pass will convert function
17// I/O types as well. Otherwise casts will be inserted at the I/O
18// boundaries.
19//
20//===----------------------------------------------------------------------===//
21
23
#include "mlir/Dialect/Tosa/Transforms/Passes.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
28
29namespace mlir {
30namespace tosa {
31#define GEN_PASS_DEF_TOSANARROWI64TOI32PASS
32#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
33} // namespace tosa
34} // namespace mlir
35
36using namespace mlir;
37using namespace mlir::tosa;
38
39namespace {
40
41LogicalResult convertGenericOp(Operation *op, ValueRange operands,
42 ConversionPatternRewriter &rewriter,
43 const TypeConverter *typeConverter) {
44 // Convert types of results
45 SmallVector<Type, 4> newResults;
46 if (failed(typeConverter->convertTypes(op->getResultTypes(), newResults)))
47 return failure();
48
49 // Create a new operation state
50 OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
51 newResults, {}, op->getSuccessors());
52
53 for (const NamedAttribute &namedAttribute : op->getAttrs()) {
54 const Attribute attribute = namedAttribute.getValue();
55
56 // Convert integer attribute type
57 if (const auto intAttr = dyn_cast<IntegerAttr>(attribute)) {
58 const std::optional<Attribute> convertedAttribute =
59 typeConverter->convertTypeAttribute(intAttr.getType(), attribute);
60 state.addAttribute(namedAttribute.getName(), convertedAttribute.value());
61 continue;
62 }
63
64 if (const auto typeAttr = dyn_cast<TypeAttr>(attribute)) {
65 Type type = typeAttr.getValue();
66 const std::optional<Attribute> convertedAttribute =
67 typeConverter->convertTypeAttribute(type, attribute);
68 if (!convertedAttribute)
69 return rewriter.notifyMatchFailure(op,
70 "Failed to convert type attribute.");
71 state.addAttribute(namedAttribute.getName(), convertedAttribute.value());
72 continue;
73 }
74
75 if (const auto denseElementsAttr = dyn_cast<DenseElementsAttr>(attribute)) {
76 const Type type = denseElementsAttr.getType();
77 const std::optional<Attribute> convertedAttribute =
78 typeConverter->convertTypeAttribute(type, denseElementsAttr);
79 if (!convertedAttribute)
80 return rewriter.notifyMatchFailure(
81 op, "Failed to convert dense elements attribute.");
82 state.addAttribute(namedAttribute.getName(), convertedAttribute.value());
83 continue;
84 }
85
86 state.addAttribute(namedAttribute.getName(), attribute);
87 }
88
89 for (Region &region : op->getRegions()) {
90 Region *newRegion = state.addRegion();
91 rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin());
92 if (failed(rewriter.convertRegionTypes(newRegion, *typeConverter)))
93 return failure();
94 }
95
96 Operation *newOp = rewriter.create(state);
97 rewriter.replaceOp(op, newOp->getResults());
98 return success();
99}
100
101// ===========================
102// Aggressive rewrite patterns
103// ===========================
104
105class ConvertGenericOp : public ConversionPattern {
106public:
107 ConvertGenericOp(TypeConverter &typeConverter, MLIRContext *context)
108 : ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, 0, context) {}
109
110 LogicalResult
111 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
112 ConversionPatternRewriter &rewriter) const final {
113 if (!isa<tosa::TosaOp>(op))
114 return rewriter.notifyMatchFailure(
115 op,
116 "Support for operations other than TOSA has not been implemented.");
117
118 return convertGenericOp(op, operands, rewriter, typeConverter);
119 }
120};
121
122// ===============================
123// Bounds checked rewrite patterns
124// ===============================
125
126class ConvertArgMaxOpWithBoundsChecking
127 : public OpConversionPattern<tosa::ArgMaxOp> {
128 using OpConversionPattern::OpConversionPattern;
129
130 LogicalResult
131 matchAndRewrite(tosa::ArgMaxOp op, OpAdaptor adaptor,
132 ConversionPatternRewriter &rewriter) const final {
133 // Output type can be narrowed based on the size of the axis dimension
134 const int32_t axis = op.getAxis();
135 const auto inputType = dyn_cast<ShapedType>(adaptor.getInput().getType());
136 if (!inputType || !inputType.isStaticDim(axis))
137 return rewriter.notifyMatchFailure(
138 op, "Requires a static axis dimension for bounds checking.");
139 const int64_t axisDim = inputType.getDimSize(axis);
140 if (axisDim >= std::numeric_limits<int32_t>::max())
141 return rewriter.notifyMatchFailure(
142 op, "Axis dimension is too large to narrow safely.");
143
144 const Type resultType = op.getOutput().getType();
145 const Type newResultType = typeConverter->convertType(resultType);
146 rewriter.replaceOpWithNewOp<tosa::ArgMaxOp>(op, newResultType,
147 adaptor.getInput(), axis);
148 return success();
149 }
150};
151
152class ConvertCastOpWithBoundsChecking
153 : public OpConversionPattern<tosa::CastOp> {
154 using OpConversionPattern::OpConversionPattern;
155
156 LogicalResult
157 matchAndRewrite(tosa::CastOp op, OpAdaptor adaptor,
158 ConversionPatternRewriter &rewriter) const final {
159 const auto inputType = dyn_cast<ShapedType>(adaptor.getInput().getType());
160 const auto resultType = dyn_cast<ShapedType>(op.getResult().getType());
161 if (!inputType || !resultType)
162 return failure();
163
164 const auto elementInputIntType =
165 dyn_cast<IntegerType>(inputType.getElementType());
166 const auto elementResultIntType =
167 dyn_cast<IntegerType>(resultType.getElementType());
168 if (elementInputIntType && elementResultIntType &&
169 elementInputIntType.getWidth() > elementResultIntType.getWidth())
170 return rewriter.notifyMatchFailure(
171 op, "Narrowing cast may lead to data loss.");
172
173 rewriter.replaceOpWithNewOp<tosa::CastOp>(
174 op, typeConverter->convertType(resultType), adaptor.getInput());
175 return success();
176 }
177};
178
179class ConvertClampOpWithBoundsChecking
180 : public OpConversionPattern<tosa::ClampOp> {
181 using OpConversionPattern::OpConversionPattern;
182
183 LogicalResult
184 matchAndRewrite(tosa::ClampOp op, OpAdaptor adaptor,
185 ConversionPatternRewriter &rewriter) const final {
186 const auto minAttr = dyn_cast<IntegerAttr>(op.getMinValAttr());
187 const auto maxAttr = dyn_cast<IntegerAttr>(op.getMaxValAttr());
188 if (!minAttr || !maxAttr)
189 return failure();
190
191 const int64_t min = minAttr.getInt();
192 const int64_t max = maxAttr.getInt();
193
194 if (min < std::numeric_limits<int32_t>::min() ||
195 max > std::numeric_limits<int32_t>::max())
196 return rewriter.notifyMatchFailure(
197 op, "Clamp bounds exceed int32 range. Narrowing cast may lead to "
198 "data loss.");
199
200 const Type resultType = op.getOutput().getType();
201 const Type newResultType = typeConverter->convertType(resultType);
202
203 const IntegerType int32Type = IntegerType::get(rewriter.getContext(), 32);
204 const IntegerAttr newMinAttr =
205 rewriter.getIntegerAttr(int32Type, static_cast<int32_t>(min));
206 const IntegerAttr newMaxAttr =
207 rewriter.getIntegerAttr(int32Type, static_cast<int32_t>(max));
208 rewriter.replaceOpWithNewOp<tosa::ClampOp>(op, newResultType,
209 adaptor.getInput(), newMinAttr,
210 newMaxAttr, op.getNanModeAttr());
211 return success();
212 }
213};
214
215template <typename OpTy>
216class ConvertTypedOp : public OpConversionPattern<OpTy> {
217 using OpConversionPattern<OpTy>::OpConversionPattern;
218
219 LogicalResult
220 matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
221 ConversionPatternRewriter &rewriter) const final {
222 return convertGenericOp(op, adaptor.getOperands(), rewriter,
223 this->getTypeConverter());
224 }
225};
226
227struct TosaNarrowI64ToI32
228 : public tosa::impl::TosaNarrowI64ToI32PassBase<TosaNarrowI64ToI32> {
229public:
230 explicit TosaNarrowI64ToI32() = default;
231 explicit TosaNarrowI64ToI32(const TosaNarrowI64ToI32PassOptions &options)
232 : TosaNarrowI64ToI32() {
233 this->aggressiveRewrite = options.aggressiveRewrite;
234 this->convertFunctionBoundaries = options.convertFunctionBoundaries;
235 }
236
237 void runOnOperation() override {
238 MLIRContext *context = &getContext();
239
240 TypeConverter typeConverter;
241 typeConverter.addConversion([](Type type) -> Type { return type; });
242 typeConverter.addConversion([](IntegerType type) -> Type {
243 if (!type.isInteger(64))
244 return type;
245 return IntegerType::get(type.getContext(), 32);
246 });
247 typeConverter.addConversion(
248 [&typeConverter](RankedTensorType type) -> Type {
249 const Type elementType = type.getElementType();
250 if (!elementType.isInteger(64))
251 return type;
252 return RankedTensorType::get(type.getShape(),
253 typeConverter.convertType(elementType));
254 });
255
256 const auto materializeCast = [](OpBuilder &builder, Type resultType,
257 ValueRange inputs, Location loc) -> Value {
258 if (inputs.size() != 1)
259 return Value();
260 return tosa::CastOp::create(builder, loc, resultType, inputs.front());
261 };
262 typeConverter.addSourceMaterialization(materializeCast);
263 typeConverter.addTargetMaterialization(materializeCast);
264
265 typeConverter.addTypeAttributeConversion(
266 [](IntegerType type, IntegerAttr attribute) -> Attribute {
267 const APInt value = attribute.getValue().truncSSat(32);
268 return IntegerAttr::get(IntegerType::get(type.getContext(), 32),
269 value);
270 });
271 typeConverter.addTypeAttributeConversion(
272 [&typeConverter](ShapedType type,
273 DenseIntElementsAttr attr) -> Attribute {
274 const ShapedType newType =
275 cast<ShapedType>(typeConverter.convertType(type));
276 const auto oldElementType = cast<IntegerType>(type.getElementType());
277 const auto newElementType =
278 cast<IntegerType>(newType.getElementType());
279 if (oldElementType.getWidth() == newElementType.getWidth())
280 return attr;
281
282 DenseElementsAttr mapped =
283 attr.mapValues(newElementType, [&](const APInt &v) {
284 return v.truncSSat(newElementType.getWidth());
285 });
286 return mapped;
287 });
288
289 ConversionTarget target(*context);
290 target.addDynamicallyLegalDialect<tosa::TosaDialect>(
291 [&typeConverter](Operation *op) {
292 return typeConverter.isLegal(op->getResultTypes()) &&
293 typeConverter.isLegal(op->getOperandTypes());
294 });
295 if (convertFunctionBoundaries) {
296 target.addDynamicallyLegalOp<func::FuncOp>(
297 [&typeConverter](func::FuncOp op) {
298 return typeConverter.isSignatureLegal(op.getFunctionType()) &&
299 typeConverter.isLegal(&op.getBody());
300 });
301 target.addDynamicallyLegalOp<func::ReturnOp>([](func::ReturnOp op) {
302 const FunctionType funcType =
303 op->getParentOfType<func::FuncOp>().getFunctionType();
304 return llvm::equal(op.getOperandTypes(), funcType.getResults());
305 });
306 } else {
307 target.addDynamicallyLegalOp<func::FuncOp>(
308 [](func::FuncOp op) { return true; });
309 target.addDynamicallyLegalOp<func::ReturnOp>(
310 [](func::ReturnOp op) { return true; });
311 }
312
313 RewritePatternSet patterns(context);
314 if (convertFunctionBoundaries) {
315 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
316 patterns, typeConverter);
318 }
319 if (aggressiveRewrite) {
320 patterns.add<ConvertGenericOp>(typeConverter, context);
321 } else {
322 // Tensor
323 patterns.add<ConvertArgMaxOpWithBoundsChecking>(typeConverter, context);
324 // Activation functions
325 patterns.add<ConvertClampOpWithBoundsChecking>(typeConverter, context);
326 // Data layout
327 patterns.add<ConvertTypedOp<tosa::ConcatOp>>(typeConverter, context);
328 patterns.add<ConvertTypedOp<tosa::PadOp>>(typeConverter, context);
329 patterns.add<ConvertTypedOp<tosa::ReshapeOp>>(typeConverter, context);
330 patterns.add<ConvertTypedOp<tosa::ReverseOp>>(typeConverter, context);
331 patterns.add<ConvertTypedOp<tosa::SliceOp>>(typeConverter, context);
332 patterns.add<ConvertTypedOp<tosa::TileOp>>(typeConverter, context);
333 patterns.add<ConvertTypedOp<tosa::TransposeOp>>(typeConverter, context);
334 patterns.add<ConvertTypedOp<tosa::IdentityOp>>(typeConverter, context);
335 // Type conversion
336 patterns.add<ConvertCastOpWithBoundsChecking>(typeConverter, context);
337 // Controlflow
338 patterns.add<ConvertTypedOp<tosa::IfOp>>(typeConverter, context);
339 patterns.add<ConvertTypedOp<tosa::WhileOp>>(typeConverter, context);
340 }
341
342 if (failed(
343 applyFullConversion(getOperation(), target, std::move(patterns))))
344 signalPassFailure();
345 }
346};
347
348} // namespace
return success()
b getContext())
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
Attributes are known-constant values of operations.
Definition Attributes.h:25
DenseElementsAttr mapValues(Type newElementType, function_ref< APInt(const APInt &)> mapping) const
Generates a new DenseElementsAttr by mapping each value attribute, and constructing the DenseElements...
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:512
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:238
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
operand_type_range getOperandTypes()
Definition Operation.h:397
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:677
result_type_range getResultTypes()
Definition Operation.h:428
SuccessorRange getSuccessors()
Definition Operation.h:703
result_range getResults()
Definition Operation.h:415
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
iterator begin()
Definition Region.h:55
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:56
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
Type getType() const
Return the type of this value.
Definition Value.h:105
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
Include the generated interface declarations.
void populateReturnOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter, PatternBenefit benefit=1)
Add a pattern to the given pattern list to rewrite return ops to use operands that have been legalize...
const FrozenRewritePatternSet & patterns
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
Region * addRegion()
Create a region that should be attached to the operation.