31#define GEN_PASS_DEF_TOSANARROWI64TOI32PASS
32#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
42 ConversionPatternRewriter &rewriter,
54 const Attribute attribute = namedAttribute.getValue();
57 if (
const auto intAttr = dyn_cast<IntegerAttr>(attribute)) {
58 const std::optional<Attribute> convertedAttribute =
59 typeConverter->convertTypeAttribute(intAttr.getType(), attribute);
60 state.
addAttribute(namedAttribute.getName(), convertedAttribute.value());
// TypeAttr: convert the wrapped type through the type converter's registered
// attribute conversions. An empty result fails the match rather than silently
// keeping the old attribute.
64 if (
const auto typeAttr = dyn_cast<TypeAttr>(attribute)) {
65 Type type = typeAttr.getValue();
66 const std::optional<Attribute> convertedAttribute =
67 typeConverter->convertTypeAttribute(type, attribute);
68 if (!convertedAttribute)
69 return rewriter.notifyMatchFailure(op,
70 "Failed to convert type attribute.");
71 state.
addAttribute(namedAttribute.getName(), convertedAttribute.value());
// DenseElementsAttr: convert the payload, keyed on the attribute's shaped type.
// As above, an empty conversion result fails the match.
// NOTE(review): presumably handled by a DenseIntElementsAttr attribute
// conversion registered by the pass — verify for non-integer payloads.
75 if (
const auto denseElementsAttr = dyn_cast<DenseElementsAttr>(attribute)) {
76 const Type type = denseElementsAttr.getType();
77 const std::optional<Attribute> convertedAttribute =
78 typeConverter->convertTypeAttribute(type, denseElementsAttr);
79 if (!convertedAttribute)
80 return rewriter.notifyMatchFailure(
81 op,
"Failed to convert dense elements attribute.");
82 state.
addAttribute(namedAttribute.getName(), convertedAttribute.value());
// Move the blocks of an original region into its counterpart on the new op.
// `region` and `newRegion` are declared on lines missing from this extract;
// presumably `newRegion` comes from state.addRegion() — confirm.
// Afterwards, convert the region's block signatures with the same type
// converter, bailing out if that fails.
91 rewriter.inlineRegionBefore(region, *newRegion, newRegion->
begin());
92 if (
failed(rewriter.convertRegionTypes(newRegion, *typeConverter)))
// Build the replacement operation from the accumulated OperationState
// (converted attributes, inlined regions, and whatever else was added on
// lines not visible here).
96 Operation *newOp = rewriter.create(state);
// Catch-all conversion pattern. It is constructed with MatchAnyOpTypeTag, so
// the driver offers it every operation. The literal 0 is presumably the
// pattern benefit (the lowest possible) — confirm against the
// ConversionPattern/RewritePattern constructor.
107 ConvertGenericOp(TypeConverter &typeConverter, MLIRContext *context)
108 : ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, 0, context) {}
// Rejects anything that is not a TOSA op. Every TOSA op is rebuilt by
// convertGenericOp from the remapped `operands`.
111 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
112 ConversionPatternRewriter &rewriter)
const final {
113 if (!isa<tosa::TosaOp>(op))
114 return rewriter.notifyMatchFailure(
116 "Support for operations other than TOSA has not been implemented.");
// `typeConverter` here is the converter stored in the ConversionPattern
// base; the constructor parameter of the same name is not in scope.
118 return convertGenericOp(op, operands, rewriter, typeConverter);
/// Rebuilds tosa.argmax with its result type converted by the pattern's type
/// converter. It first checks that the reduced axis has a static size below
/// std::numeric_limits<int32_t>::max(), so every index produced along that
/// axis fits in a 32-bit integer.
126class ConvertArgMaxOpWithBoundsChecking
127 :
public OpConversionPattern<tosa::ArgMaxOp> {
128 using OpConversionPattern::OpConversionPattern;
131 matchAndRewrite(tosa::ArgMaxOp op, OpAdaptor adaptor,
132 ConversionPatternRewriter &rewriter)
const final {
134 const int32_t axis = op.getAxis();
// The bounds check needs a ranked, shaped input with a known (static) size
// along `axis`; dynamic sizes cannot be proven to fit.
135 const auto inputType = dyn_cast<ShapedType>(adaptor.getInput().getType());
136 if (!inputType || !inputType.isStaticDim(axis))
137 return rewriter.notifyMatchFailure(
138 op,
"Requires a static axis dimension for bounds checking.");
139 const int64_t axisDim = inputType.getDimSize(axis);
// Indices lie in [0, axisDim). Using >= also rejects axisDim == INT32_MAX,
// which is slightly conservative but always safe.
140 if (axisDim >= std::numeric_limits<int32_t>::max())
141 return rewriter.notifyMatchFailure(
142 op,
"Axis dimension is too large to narrow safely.");
144 const Type resultType = op.getOutput().getType();
145 const Type newResultType = typeConverter->convertType(resultType);
// NOTE(review): only the converted result type, the adaptor input and `axis`
// are passed to the new op. Any other attribute on the original argmax (e.g. a
// NaN-propagation mode, if this TOSA version defines one) is not forwarded
// here — confirm it is defaulted identically or intentionally dropped.
146 rewriter.replaceOpWithNewOp<tosa::ArgMaxOp>(op, newResultType,
147 adaptor.getInput(), axis);
/// Rebuilds tosa.cast with its result type converted by the pattern's type
/// converter. It refuses to do so when both the input and result element
/// types are integers and the input is wider than the result.
152class ConvertCastOpWithBoundsChecking
153 :
public OpConversionPattern<tosa::CastOp> {
154 using OpConversionPattern::OpConversionPattern;
157 matchAndRewrite(tosa::CastOp op, OpAdaptor adaptor,
158 ConversionPatternRewriter &rewriter)
const final {
// The input type is taken from the adaptor, i.e. the operand as remapped by
// the conversion.
// NOTE(review): `resultType` is declared on a line missing from this extract,
// presumably as a ShapedType of the original op's result — verify. The
// handling of the null case below is also on lines not visible here.
159 const auto inputType = dyn_cast<ShapedType>(adaptor.getInput().getType());
161 if (!inputType || !resultType)
164 const auto elementInputIntType =
165 dyn_cast<IntegerType>(inputType.getElementType());
166 const auto elementResultIntType =
167 dyn_cast<IntegerType>(resultType.getElementType());
// Only integer-to-integer casts are width-checked. Casts involving any
// non-integer element type skip this check.
168 if (elementInputIntType && elementResultIntType &&
169 elementInputIntType.getWidth() > elementResultIntType.getWidth())
170 return rewriter.notifyMatchFailure(
171 op,
"Narrowing cast may lead to data loss.");
173 rewriter.replaceOpWithNewOp<tosa::CastOp>(
174 op, typeConverter->convertType(resultType), adaptor.getInput());
/// Op-specific counterpart of the catch-all generic pattern. It matches only
/// `OpTy` and rebuilds it through convertGenericOp from the adaptor's
/// (remapped) operands, using this pattern's type converter.
179template <
typename OpTy>
180class ConvertTypedOp :
public OpConversionPattern<OpTy> {
181 using OpConversionPattern<OpTy>::OpConversionPattern;
184 matchAndRewrite(OpTy op,
typename OpTy::Adaptor adaptor,
185 ConversionPatternRewriter &rewriter)
const final {
// `this->` is required because getTypeConverter() lives in a base class that
// depends on the template parameter.
186 return convertGenericOp(op, adaptor.getOperands(), rewriter,
187 this->getTypeConverter());
/// Pass that narrows 64-bit integer types in TOSA IR to 32-bit. The visible
/// code does the following:
///  - registers an i64 -> i32 type conversion (other integer widths map to
///    themselves) and element-type conversion for ranked tensors;
///  - registers tosa.cast materializations;
///  - registers attribute conversions for integer and dense-integer
///    attributes;
///  - builds a conversion target and pattern set shaped by the
///    `aggressiveRewrite` and `convertFunctionBoundaries` options.
/// NOTE(review): many lines of this struct are missing from this extract; the
/// comments describe only the visible code.
191struct TosaNarrowI64ToI32
194 explicit TosaNarrowI64ToI32() =
default;
// Copies the two user-facing options onto the pass's own option members.
195 explicit TosaNarrowI64ToI32(
const TosaNarrowI64ToI32PassOptions &
options)
196 : TosaNarrowI64ToI32() {
197 this->aggressiveRewrite =
options.aggressiveRewrite;
198 this->convertFunctionBoundaries =
options.convertFunctionBoundaries;
201 void runOnOperation()
override {
// Type conversions. MLIR tries the most recently added conversion first, so
// the identity conversion below acts as the fallback.
204 TypeConverter typeConverter;
205 typeConverter.addConversion([](Type type) -> Type {
return type; });
// Only i64 is narrowed (to i32); the non-i64 path is on a line missing here.
206 typeConverter.addConversion([](IntegerType type) -> Type {
207 if (!type.isInteger(64))
209 return IntegerType::get(type.getContext(), 32);
// Ranked tensors keep their shape and get their element type converted.
211 typeConverter.addConversion(
212 [&typeConverter](RankedTensorType type) -> Type {
213 const Type elementType = type.getElementType();
216 return RankedTensorType::get(type.getShape(),
217 typeConverter.convertType(elementType));
// Converted and unconverted values are bridged with a single tosa.cast. Only
// single-input materializations are handled; the multi-input result is on a
// line missing from this extract.
220 const auto materializeCast = [](OpBuilder &builder, Type resultType,
222 if (inputs.size() != 1)
224 return tosa::CastOp::create(builder, loc, resultType, inputs.front());
226 typeConverter.addSourceMaterialization(materializeCast);
227 typeConverter.addTargetMaterialization(materializeCast);
// Integer attributes: saturate the value to 32 bits and retype it as i32.
// NOTE(review): this callback accepts every IntegerType attribute, not only
// i64 ones, whereas the integer type conversion above leaves non-i64 widths
// unchanged. APInt::truncSSat also requires the requested width not to exceed
// the value's current width. Confirm that attributes narrower than 32 bits
// (e.g. i1 BoolAttr) cannot reach this path, or return `attribute` unchanged
// unless type.isInteger(64).
229 typeConverter.addTypeAttributeConversion(
230 [](IntegerType type, IntegerAttr attribute) -> Attribute {
231 const APInt value = attribute.getValue().truncSSat(32);
232 return IntegerAttr::get(IntegerType::get(type.getContext(), 32),
// Dense integer element attributes: convert the shaped type. If the element
// width changed, map every value through a signed-saturating truncation to
// the new width. The equal-width path is on lines missing here.
235 typeConverter.addTypeAttributeConversion(
236 [&typeConverter](ShapedType type,
237 DenseIntElementsAttr attr) -> Attribute {
238 const ShapedType newType =
239 cast<ShapedType>(typeConverter.convertType(type));
240 const auto oldElementType = cast<IntegerType>(type.getElementType());
241 const auto newElementType =
242 cast<IntegerType>(newType.getElementType());
243 if (oldElementType.getWidth() == newElementType.getWidth())
246 DenseElementsAttr mapped =
247 attr.
mapValues(newElementType, [&](
const APInt &v) {
248 return v.truncSSat(newElementType.getWidth());
// Legality: TOSA ops are dynamically legal according to a predicate that
// captures the type converter (its body is not visible in this extract).
253 ConversionTarget
target(*context);
254 target.addDynamicallyLegalDialect<tosa::TosaDialect>(
255 [&typeConverter](Operation *op) {
// With convertFunctionBoundaries, func.func is legal only when both its
// signature and its body are legal under the type converter. func.return
// legality depends on a FunctionType computed on lines missing here.
// Without the option, func.func and func.return are always legal.
259 if (convertFunctionBoundaries) {
260 target.addDynamicallyLegalOp<func::FuncOp>(
261 [&typeConverter](func::FuncOp op) {
262 return typeConverter.isSignatureLegal(op.getFunctionType()) &&
263 typeConverter.isLegal(&op.getBody());
265 target.addDynamicallyLegalOp<func::ReturnOp>([](func::ReturnOp op) {
266 const FunctionType funcType =
271 target.addDynamicallyLegalOp<func::FuncOp>(
272 [](func::FuncOp op) {
return true; });
273 target.addDynamicallyLegalOp<func::ReturnOp>(
274 [](func::ReturnOp op) {
return true; });
// Patterns: function-signature conversion when converting boundaries, and
// the catch-all generic pattern in aggressive mode. The bounds-checked and
// op-specific patterns follow.
// NOTE(review): how the op-specific patterns are grouped relative to the
// aggressiveRewrite branch is on lines missing from this extract.
277 RewritePatternSet
patterns(context);
278 if (convertFunctionBoundaries) {
279 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
283 if (aggressiveRewrite) {
284 patterns.add<ConvertGenericOp>(typeConverter, context);
287 patterns.add<ConvertArgMaxOpWithBoundsChecking>(typeConverter, context);
289 patterns.add<ConvertTypedOp<tosa::ConcatOp>>(typeConverter, context);
290 patterns.add<ConvertTypedOp<tosa::PadOp>>(typeConverter, context);
291 patterns.add<ConvertTypedOp<tosa::ReshapeOp>>(typeConverter, context);
292 patterns.add<ConvertTypedOp<tosa::ReverseOp>>(typeConverter, context);
293 patterns.add<ConvertTypedOp<tosa::SliceOp>>(typeConverter, context);
294 patterns.add<ConvertTypedOp<tosa::TileOp>>(typeConverter, context);
295 patterns.add<ConvertTypedOp<tosa::TransposeOp>>(typeConverter, context);
296 patterns.add<ConvertTypedOp<tosa::IdentityOp>>(typeConverter, context);
298 patterns.add<ConvertCastOpWithBoundsChecking>(typeConverter, context);
300 patterns.add<ConvertTypedOp<tosa::IfOp>>(typeConverter, context);
301 patterns.add<ConvertTypedOp<tosa::WhileOp>>(typeConverter, context);
// Full conversion: every remaining op must be legal under `target`. The
// failure handling follows on lines not visible in this extract.
305 applyFullConversion(getOperation(),
target, std::move(
patterns))))
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
DenseElementsAttr mapValues(Type newElementType, function_ref< APInt(const APInt &)> mapping) const
Generates a new DenseElementsAttr by mapping each value attribute, and constructing the DenseElements...
NamedAttribute represents a combination of a name and an Attribute value.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
result_type_range getResultTypes()
SuccessorRange getSuccessors()
result_range getResults()
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class provides an abstraction over the different types of ranges over Values.
Type getType() const
Return the type of this value.
Include the generated interface declarations.
void populateReturnOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter, PatternBenefit benefit=1)
Add a pattern to the given pattern list to rewrite return ops to use operands that have been legalize...
const FrozenRewritePatternSet & patterns
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
Region * addRegion()
Create a region that should be attached to the operation.