34 if (!newShape.empty())
35 return input.getType();
43 return input.getType().clone(shape);
56 bool inputIsStatic = inputType.hasStaticShape();
57 int64_t totalSize = inputIsStatic ? inputType.getNumElements() : -1;
60 auto resultShape = llvm::map_to_vector(newShape, [&](int64_t size) -> int64_t {
68 return ShapedType::kDynamic;
72 int64_t totalSizeNoPlaceholder = -std::accumulate(
73 newShape.begin(), newShape.end(), 1, std::multiplies<int64_t>());
76 if (totalSizeNoPlaceholder == 0)
81 return totalSize / totalSizeNoPlaceholder;
84 bool resultIsStatic = !ShapedType::isDynamicShape(resultShape);
89 if (!inputIsStatic && resultIsStatic)
90 resultShape[0] = ShapedType::kDynamic;
95 assert(!inputIsStatic || resultIsStatic);
98 return inputType.
clone(resultShape);
107 if (lhsShape.empty() || rhsShape.empty())
110 if (ShapedType::isDynamicShape(lhsShape) || ShapedType::isDynamicShape(rhsShape))
111 return lhsType.
clone({ShapedType::kDynamic});
114 unsigned currLhsDim = 0, currRhsDim = 0;
115 while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) {
116 int64_t rhsSize = rhsShape[currRhsDim];
117 int64_t lhsSize = lhsShape[currLhsDim];
118 while (lhsSize != rhsSize && currLhsDim < lhsShape.size() &&
119 currRhsDim < rhsShape.size()) {
120 if (lhsSize < rhsSize) {
122 if (currLhsDim < lhsShape.size()) {
123 lhsSize *= lhsShape[currLhsDim];
127 if (currRhsDim < rhsShape.size()) {
128 rhsSize *= rhsShape[currRhsDim];
132 if (lhsSize == rhsSize) {
133 intermediateShape.push_back(lhsSize);
141 for (; currLhsDim < lhsShape.size(); currLhsDim++) {
142 assert(lhsShape[currLhsDim] == 1);
144 for (; currRhsDim < rhsShape.size(); currRhsDim++) {
145 assert(rhsShape[currRhsDim] == 1);
148 return lhsType.
clone(intermediateShape);
152 createReassociationMapForCollapse(
OpBuilder &builder,
Type srcType,
Type dstType) {
153 auto srcShape = cast<TensorType>(srcType).getShape();
154 auto dstShape = cast<TensorType>(dstType).getShape();
156 if (srcShape.empty() || dstShape.empty())
159 if (ShapedType::isDynamicShape(srcShape) || ShapedType::isDynamicShape(dstShape)) {
160 assert(dstShape.size() == 1);
162 for (
auto i : llvm::seq<int64_t>(srcShape.size()))
168 unsigned currSrcDim = 0, currDstDim = 0;
169 while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
170 int64_t dstSize = dstShape[currDstDim];
171 int64_t srcSize = srcShape[currSrcDim];
172 while (srcSize < dstSize && currSrcDim < srcShape.size()) {
173 reassociationMap[currDstDim].push_back(
175 srcSize *= srcShape[currSrcDim];
177 if (srcSize == dstSize) {
178 reassociationMap[currDstDim].push_back(
182 if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
183 while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
184 reassociationMap[currDstDim].push_back(
195 assert(currSrcDim == srcShape.size() && currDstDim == dstShape.size());
196 return reassociationMap;
203 auto reassociationMap =
204 createReassociationMapForCollapse(builder, input.
getType(), resultType);
205 return builder.
createOrFold<tensor::CollapseShapeOp>(loc, resultType, input,
213 auto reassociationMap =
214 createReassociationMapForCollapse(builder, resultType, input.
getType());
215 return builder.
createOrFold<tensor::ExpandShapeOp>(loc, resultType, input,
224 matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
226 auto loc = reshape.getLoc();
227 auto resultType = cast_if_present<ShapedType>(
228 getTypeConverter()->convertType(reshape.getType()));
230 return rewriter.notifyMatchFailure(reshape.getLoc(),
231 "could not convert result type");
233 auto input = dyn_cast<TypedValue<TensorType>>(adaptor.getInput1());
235 return rewriter.notifyMatchFailure(reshape.getLoc(),
236 "expected input type to be tensor");
238 auto newShape = reshape.getNewShape();
241 auto inputType = inferReshapeInputType(input, newShape);
242 auto expandedType = inferReshapeExpandedType(inputType, newShape);
243 auto collapsedType = inferReshapeCollapsedType(inputType, expandedType);
246 auto castInput = rewriter.createOrFold<tensor::CastOp>(loc, inputType, input);
249 auto collapsed = createCollapse(rewriter, loc, collapsedType, castInput);
250 auto expanded = createExpand(rewriter, loc, expandedType, collapsed);
253 auto result = rewriter.createOrFold<tensor::CastOp>(loc, resultType, expanded);
254 rewriter.replaceOp(reshape, result);
264 matchAndRewrite(tosa::SliceOp sliceOp, OpAdaptor adaptor,
267 Value input = adaptor.getInput1();
268 ShapedType resultType = cast<ShapedType>(sliceOp.getType());
269 if (llvm::isa<UnrankedTensorType>(resultType))
273 strides.resize(cast<ShapedType>(sliceOp.getType()).getRank(), 1);
277 int64_t size = i.value();
278 size_t index = i.index();
279 sizes.push_back(size == -1 ? ShapedType::kDynamic : size);
280 if (!ShapedType::isDynamic(sizes.back()))
283 auto dim = rewriter.create<tensor::DimOp>(loc, input, index);
284 auto offset = rewriter.create<arith::ConstantOp>(
285 loc, rewriter.getIndexAttr(starts[index]));
286 dynSizes.push_back(rewriter.create<arith::SubIOp>(loc, dim, offset));
289 auto newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
290 sliceOp.getLoc(), sliceOp.getType(), input,
ValueRange({}), dynSizes,
291 ValueRange({}), rewriter.getDenseI64ArrayAttr(starts),
292 rewriter.getDenseI64ArrayAttr(sizes),
293 rewriter.getDenseI64ArrayAttr(strides));
295 rewriter.replaceOp(sliceOp, newSliceOp.getResult());
305 matchAndRewrite(tosa::PadOp padOp, OpAdaptor adaptor,
307 auto loc = padOp.getLoc();
308 auto input = padOp.getInput1();
309 auto padding = padOp.getPadding();
311 ShapedType inputTy = cast<ShapedType>(input.getType());
312 Type elementTy = inputTy.getElementType();
313 int64_t rank = inputTy.getRank();
319 if (padOp.getPadConst()) {
320 padConstant = rewriter.createOrFold<tensor::ExtractOp>(
323 TypedAttr constantAttr;
324 if (isa<FloatType>(elementTy)) {
325 constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
326 }
else if (isa<IntegerType>(elementTy) && !padOp.getQuantizationInfo()) {
327 constantAttr = rewriter.getIntegerAttr(elementTy, 0);
328 }
else if (isa<IntegerType>(elementTy) && padOp.getQuantizationInfo()) {
329 int64_t value = padOp.getQuantizationInfo()->getInputZp();
330 constantAttr = rewriter.getIntegerAttr(elementTy, value);
333 padConstant = rewriter.create<arith::ConstantOp>(loc, constantAttr);
337 return rewriter.notifyMatchFailure(
338 padOp,
"tosa.pad was unable to determine the pad constant value.");
342 rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
344 rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
349 lowValues.reserve(rank);
350 highValues.reserve(rank);
352 for (
int i = 0; i < rank; i++) {
353 Value inputIndex = rewriter.create<arith::ConstantIndexOp>(loc, i);
354 Value lowVal = rewriter.createOrFold<tensor::ExtractOp>(
355 loc, padding,
ValueRange({inputIndex, lowIndex}));
356 Value highVal = rewriter.createOrFold<tensor::ExtractOp>(
357 loc, padding,
ValueRange({inputIndex, highIndex}));
359 lowVal = rewriter.createOrFold<arith::IndexCastOp>(
360 loc, rewriter.getIndexType(), lowVal);
361 highVal = rewriter.createOrFold<arith::IndexCastOp>(
362 loc, rewriter.getIndexType(), highVal);
364 lowValues.push_back(lowVal);
365 highValues.push_back(highVal);
368 auto newPadOp = rewriter.create<tensor::PadOp>(
369 loc, padOp.getType(), input, lowValues, highValues, padConstant);
371 rewriter.replaceOp(padOp, newPadOp.getResult());
380 matchAndRewrite(tosa::ConcatOp op, OpAdaptor adaptor,
382 auto resultType = dyn_cast<RankedTensorType>(op.getType());
385 int axis = op.getAxis();
388 int64_t rank = resultType.getRank();
400 axisOffsets.push_back(sizes[axis]);
402 for (
auto arg : adaptor.getOperands().drop_front()) {
403 auto size = rewriter.
createOrFold<tensor::DimOp>(loc, arg, axisValue);
407 rewriter.
createOrFold<arith::AddIOp>(loc, currentOffset, size);
410 sizes[axis] = axisOffsets.back();
417 for (int64_t i = 0; i < rank; ++i) {
418 if (resultType.isDynamicDim(i)) {
425 loc, resultType.getShape(), resultType.getElementType(), dynDims);
427 for (
auto [arg, offset] : llvm::zip(adaptor.getOperands(), axisOffsets)) {
429 offsets[axis] = offset;
431 loc, arg, result, offsets, sizes, strides);
443 ->add<ConcatConverter, PadConverter, ReshapeConverter, SliceConverter>(
IntegerAttr getIndexAttr(int64_t value)
AffineExpr getAffineDimExpr(unsigned position)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
RankedTensorType clone(ArrayRef< int64_t > shape, Type elementType) const
Return a clone of this type with the given new shape and element type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
void populateTosaToTensorConversionPatterns(const TypeConverter &converter, RewritePatternSet *patterns)
Include the generated interface declarations.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
const FrozenRewritePatternSet & patterns
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.