34 if (!newShape.empty())
35 return input.getType();
43 return input.getType().clone(shape);
56 bool inputIsStatic = inputType.hasStaticShape();
57 int64_t totalSize = inputIsStatic ? inputType.getNumElements() : -1;
61 llvm::map_to_vector(newShape, [&](int64_t size) -> int64_t {
69 return ShapedType::kDynamic;
73 int64_t totalSizeNoPlaceholder = -std::accumulate(
74 newShape.begin(), newShape.end(), 1, std::multiplies<int64_t>());
78 if (totalSizeNoPlaceholder == 0)
83 return totalSize / totalSizeNoPlaceholder;
86 bool resultIsStatic = ShapedType::isStaticShape(resultShape);
91 if (!inputIsStatic && resultIsStatic)
92 resultShape[0] = ShapedType::kDynamic;
97 assert(!inputIsStatic || resultIsStatic);
100 return inputType.
clone(resultShape);
109 if (lhsShape.empty() || rhsShape.empty())
112 if (ShapedType::isDynamicShape(lhsShape) ||
113 ShapedType::isDynamicShape(rhsShape))
114 return lhsType.
clone({ShapedType::kDynamic});
117 unsigned currLhsDim = 0, currRhsDim = 0;
118 while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) {
119 int64_t rhsSize = rhsShape[currRhsDim];
120 int64_t lhsSize = lhsShape[currLhsDim];
121 while (lhsSize != rhsSize && currLhsDim < lhsShape.size() &&
122 currRhsDim < rhsShape.size()) {
123 if (lhsSize < rhsSize) {
125 if (currLhsDim < lhsShape.size()) {
126 lhsSize *= lhsShape[currLhsDim];
130 if (currRhsDim < rhsShape.size()) {
131 rhsSize *= rhsShape[currRhsDim];
135 if (lhsSize == rhsSize) {
136 intermediateShape.push_back(lhsSize);
144 for (; currLhsDim < lhsShape.size(); currLhsDim++) {
145 assert(lhsShape[currLhsDim] == 1);
147 for (; currRhsDim < rhsShape.size(); currRhsDim++) {
148 assert(rhsShape[currRhsDim] == 1);
151 return lhsType.
clone(intermediateShape);
155 createReassociationMapForCollapse(
OpBuilder &builder,
Type srcType,
157 auto srcShape = cast<TensorType>(srcType).getShape();
158 auto dstShape = cast<TensorType>(dstType).getShape();
160 if (srcShape.empty() || dstShape.empty())
163 if (ShapedType::isDynamicShape(srcShape) ||
164 ShapedType::isDynamicShape(dstShape)) {
165 assert(dstShape.size() == 1);
167 for (
auto i : llvm::seq<int64_t>(srcShape.size()))
173 unsigned currSrcDim = 0, currDstDim = 0;
174 while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
175 int64_t dstSize = dstShape[currDstDim];
176 int64_t srcSize = srcShape[currSrcDim];
177 while (srcSize < dstSize && currSrcDim < srcShape.size()) {
178 reassociationMap[currDstDim].push_back(
180 srcSize *= srcShape[currSrcDim];
182 if (srcSize == dstSize) {
183 reassociationMap[currDstDim].push_back(
187 if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
188 while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
189 reassociationMap[currDstDim].push_back(
200 assert(currSrcDim == srcShape.size() && currDstDim == dstShape.size());
201 return reassociationMap;
208 auto reassociationMap =
209 createReassociationMapForCollapse(builder, input.
getType(), resultType);
210 return builder.
createOrFold<tensor::CollapseShapeOp>(loc, resultType, input,
218 auto reassociationMap =
219 createReassociationMapForCollapse(builder, resultType, input.
getType());
220 return builder.
createOrFold<tensor::ExpandShapeOp>(loc, resultType, input,
229 matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
231 auto loc = reshape.getLoc();
233 getTypeConverter()->convertType<ShapedType>(reshape.getType());
235 return rewriter.notifyMatchFailure(reshape.getLoc(),
236 "could not convert result type");
238 auto input = dyn_cast<TypedValue<TensorType>>(adaptor.getInput1());
240 return rewriter.notifyMatchFailure(reshape.getLoc(),
241 "expected input type to be tensor");
251 auto inputType = inferReshapeInputType(input, newShape);
252 auto expandedType = inferReshapeExpandedType(inputType, newShape);
253 auto collapsedType = inferReshapeCollapsedType(inputType, expandedType);
257 rewriter.createOrFold<tensor::CastOp>(loc, inputType, input);
260 auto collapsed = createCollapse(rewriter, loc, collapsedType, castInput);
261 auto expanded = createExpand(rewriter, loc, expandedType, collapsed);
265 rewriter.createOrFold<tensor::CastOp>(loc, resultType, expanded);
266 rewriter.replaceOp(reshape, result);
276 matchAndRewrite(tosa::SliceOp sliceOp, OpAdaptor adaptor,
279 Value input = adaptor.getInput1();
280 ShapedType resultType = cast<ShapedType>(sliceOp.getType());
281 if (llvm::isa<UnrankedTensorType>(resultType))
284 ElementsAttr startElems;
285 ElementsAttr sizeElems;
288 return rewriter.notifyMatchFailure(
289 sliceOp,
"start of slice must be a static ranked shape");
292 return rewriter.notifyMatchFailure(
293 sliceOp,
"size of slice must be a static ranked shape");
296 llvm::to_vector(startElems.getValues<int64_t>());
298 llvm::to_vector(sizeElems.getValues<int64_t>());
301 strides.resize(cast<ShapedType>(sliceOp.getType()).getRank(), 1);
305 int64_t size = i.value();
306 size_t index = i.index();
307 sizes.push_back(size == -1 ? ShapedType::kDynamic : size);
308 if (ShapedType::isStatic(sizes.back()))
311 auto dim = tensor::DimOp::create(rewriter, loc, input, index);
312 auto offset = arith::ConstantOp::create(
313 rewriter, loc, rewriter.getIndexAttr(sliceStarts[index]));
314 dynSizes.push_back(arith::SubIOp::create(rewriter, loc, dim, offset));
317 auto newSliceOp = tensor::ExtractSliceOp::create(
318 rewriter, sliceOp.getLoc(), sliceOp.getType(), input,
ValueRange({}),
319 dynSizes,
ValueRange({}), rewriter.getDenseI64ArrayAttr(sliceStarts),
320 rewriter.getDenseI64ArrayAttr(sizes),
321 rewriter.getDenseI64ArrayAttr(strides));
324 Operation *startConstShape = sliceOp.getStart().getDefiningOp();
326 rewriter.eraseOp(startConstShape);
328 Operation *sizeConstShape = sliceOp.getSize().getDefiningOp();
330 rewriter.eraseOp(sizeConstShape);
332 rewriter.replaceOp(sliceOp, newSliceOp.getResult());
342 matchAndRewrite(tosa::PadOp padOp, OpAdaptor adaptor,
344 auto loc = padOp.getLoc();
345 auto input = padOp.getInput1();
347 ElementsAttr paddingElems;
349 return rewriter.notifyMatchFailure(
350 padOp,
"padding must be a static shape value");
353 for (
auto idx : paddingElems.getValues<IntegerAttr>()) {
354 paddingVals.push_back(
static_cast<int64_t
>(idx.getInt()));
357 ShapedType inputTy = cast<ShapedType>(input.
getType());
358 int64_t rank = inputTy.getRank();
362 Value padConstant = rewriter.createOrFold<tensor::ExtractOp>(
363 loc, padOp.getPadConst(),
367 return rewriter.notifyMatchFailure(
368 padOp,
"tosa.pad was unable to determine the pad constant value.");
374 lowValues.reserve(rank);
375 highValues.reserve(rank);
377 for (
int i = 0; i < rank; i++) {
378 Value lowVal = arith::ConstantOp::create(
379 rewriter, loc, rewriter.getIndexAttr(paddingVals[2 * i]));
380 Value highVal = arith::ConstantOp::create(
381 rewriter, loc, rewriter.getIndexAttr(paddingVals[2 * i + 1]));
382 lowValues.push_back(lowVal);
383 highValues.push_back(highVal);
386 auto newPadOp = tensor::PadOp::create(rewriter, loc, padOp.getType(), input,
387 lowValues, highValues, padConstant);
389 rewriter.replaceOp(padOp, newPadOp.getResult());
398 matchAndRewrite(tosa::ConcatOp op, OpAdaptor adaptor,
400 auto resultType = dyn_cast<RankedTensorType>(op.getType());
403 int axis = op.getAxis();
405 arith::ConstantOp::create(rewriter, loc, rewriter.
getIndexAttr(axis));
406 int64_t rank = resultType.getRank();
418 axisOffsets.push_back(sizes[axis]);
420 for (
auto arg : adaptor.getOperands().drop_front()) {
421 auto size = rewriter.
createOrFold<tensor::DimOp>(loc, arg, axisValue);
425 rewriter.
createOrFold<arith::AddIOp>(loc, currentOffset, size);
428 sizes[axis] = axisOffsets.back();
435 for (int64_t i = 0; i < rank; ++i) {
436 if (resultType.isDynamicDim(i)) {
443 tensor::EmptyOp::create(rewriter, loc, resultType.getShape(),
444 resultType.getElementType(), dynDims);
446 for (
auto [arg, offset] : llvm::zip(adaptor.getOperands(), axisOffsets)) {
448 offsets[axis] = offset;
450 loc, arg, result, offsets, sizes, strides);
462 ->add<ConcatConverter, PadConverter, ReshapeConverter, SliceConverter>(
IntegerAttr getIndexAttr(int64_t value)
AffineExpr getAffineDimExpr(unsigned position)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
RankedTensorType clone(ArrayRef< int64_t > shape, Type elementType) const
Return a clone of this type with the given new shape and element type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
bool hasOneUse() const
Returns true if this value has exactly one use.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
void populateTosaToTensorConversionPatterns(const TypeConverter &converter, RewritePatternSet *patterns)
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
const FrozenRewritePatternSet & patterns
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.