static TensorType inferReshapeInputType(TypedValue<TensorType> input,
                                        ArrayRef<int64_t> newShape) {
  // A non-empty target shape leaves the input type unchanged.
  if (!newShape.empty())
    return input.getType();
  // For a 0D result, cast to a same-rank tensor with all unit dimensions so
  // the collapse below never folds a dynamic tensor directly into 0D.
  SmallVector<int64_t> shape(input.getType().getRank(), 1);
  return input.getType().clone(shape);
}
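// Illustrative only (not part of the lowering): two sample applications of
// the inference above, assuming f32 element type.
//   inferReshapeInputType(%in : tensor<?x?xf32>, newShape = {})
//       -> tensor<1x1xf32>   // 0D target: rank kept, all dims forced to 1
//   inferReshapeInputType(%in : tensor<2x3xf32>, newShape = {6})
//       -> tensor<2x3xf32>   // non-empty target: type unchanged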
static TensorType inferReshapeExpandedType(TensorType inputType,
                                           ArrayRef<int64_t> newShape) {
  // If the input shape is static, its total element count is known.
  bool inputIsStatic = inputType.hasStaticShape();
  int64_t totalSize = inputIsStatic ? inputType.getNumElements() : -1;

  // Compute the result shape, resolving any -1 placeholder.
  auto resultShape =
      llvm::map_to_vector(newShape, [&](int64_t size) -> int64_t {
        // Keep non-placeholder sizes as they are.
        if (size >= 0)
          return size;
        // With an unknown total size, the placeholder stays dynamic.
        if (!inputIsStatic)
          return ShapedType::kDynamic;
        // Product of all sizes except the placeholder; negating the
        // accumulated product discards the -1 factor.
        int64_t totalSizeNoPlaceholder = -std::accumulate(
            newShape.begin(), newShape.end(), 1, std::multiplies<int64_t>());
        // A zero component in 'newShape' resolves the placeholder to 0.
        if (totalSizeNoPlaceholder == 0)
          return 0;
        return totalSize / totalSizeNoPlaceholder;
      });

  bool resultIsStatic = !ShapedType::isDynamicShape(resultShape);

  // 'tensor.expand_shape' forbids expanding a dynamically shaped input into
  // a fully static result, so make the first result dimension dynamic.
  if (!inputIsStatic && resultIsStatic)
    resultShape[0] = ShapedType::kDynamic;

  // The opposite case (static input, dynamic result) cannot occur with the
  // placeholder resolution above.
  assert(!inputIsStatic || resultIsStatic);

  return inputType.clone(resultShape);
}
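// Worked example (illustrative values): for inputType = tensor<2x3x4xf32>
// (totalSize = 24) and newShape = {6, -1}, the accumulate yields
// -(6 * -1) = 6, so the placeholder resolves to 24 / 6 = 4 and the result
// is tensor<6x4xf32>. For a dynamic input such as tensor<?x3xf32>, the
// placeholder stays dynamic and the result is tensor<6x?xf32>.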
static TensorType inferReshapeCollapsedType(TensorType lhsType,
                                            TensorType rhsType) {
  auto lhsShape = lhsType.getShape();
  auto rhsShape = rhsType.getShape();

  if (lhsShape.empty() || rhsShape.empty())
    return lhsType.clone(ArrayRef<int64_t>{});

  if (ShapedType::isDynamicShape(lhsShape) ||
      ShapedType::isDynamicShape(rhsShape))
    return lhsType.clone({ShapedType::kDynamic});

  SmallVector<int64_t> intermediateShape;
  unsigned currLhsDim = 0, currRhsDim = 0;
  while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) {
    int64_t rhsSize = rhsShape[currRhsDim];
    int64_t lhsSize = lhsShape[currLhsDim];
    // Grow whichever side is smaller until the running products match.
    while (lhsSize != rhsSize && currLhsDim < lhsShape.size() &&
           currRhsDim < rhsShape.size()) {
      if (lhsSize < rhsSize) {
        currLhsDim++;
        if (currLhsDim < lhsShape.size())
          lhsSize *= lhsShape[currLhsDim];
      } else {
        currRhsDim++;
        if (currRhsDim < rhsShape.size())
          rhsSize *= rhsShape[currRhsDim];
      }
    }
    if (lhsSize == rhsSize)
      intermediateShape.push_back(lhsSize);
    currLhsDim++;
    currRhsDim++;
  }

  // Any leftover dimensions must be unit; a true mismatch is caught by the
  // reshape op verifier.
  for (; currLhsDim < lhsShape.size(); currLhsDim++)
    assert(lhsShape[currLhsDim] == 1);
  for (; currRhsDim < rhsShape.size(); currRhsDim++)
    assert(rhsShape[currRhsDim] == 1);

  return lhsType.clone(intermediateShape);
}
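// Worked example (illustrative values, f32 assumed): for lhsShape = {2, 3, 4}
// and rhsShape = {4, 6}, the running products only meet at 2*3*4 = 4*6 = 24,
// so the coarsest common shape is tensor<24xf32>. For lhsShape = {2, 3, 4}
// and rhsShape = {6, 4}, the products match at 6 and again at 4, giving
// tensor<6x4xf32>.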
static SmallVector<ReassociationExprs>
createReassociationMapForCollapse(OpBuilder &builder, Type srcType,
                                  Type dstType) {
  auto srcShape = cast<TensorType>(srcType).getShape();
  auto dstShape = cast<TensorType>(dstType).getShape();

  if (srcShape.empty() || dstShape.empty())
    return {};

  // With dynamic shapes, every source dimension collapses into the single
  // destination dimension.
  if (ShapedType::isDynamicShape(srcShape) ||
      ShapedType::isDynamicShape(dstShape)) {
    assert(dstShape.size() == 1);
    SmallVector<AffineExpr, 2> exprs;
    for (auto i : llvm::seq<int64_t>(srcShape.size()))
      exprs.push_back(builder.getAffineDimExpr(i));
    return {exprs};
  }

  SmallVector<ReassociationExprs> reassociationMap(dstShape.size());
  unsigned currSrcDim = 0, currDstDim = 0;
  while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
    int64_t dstSize = dstShape[currDstDim];
    int64_t srcSize = srcShape[currSrcDim];
    while (srcSize < dstSize && currSrcDim < srcShape.size()) {
      reassociationMap[currDstDim].push_back(
          builder.getAffineDimExpr(currSrcDim++));
      srcSize *= srcShape[currSrcDim];
    }
    if (srcSize == dstSize) {
      reassociationMap[currDstDim].push_back(
          builder.getAffineDimExpr(currSrcDim++));
      // If the next destination dim is not 1, fold any trailing unit source
      // dims into the current group.
      if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
        while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
          reassociationMap[currDstDim].push_back(
              builder.getAffineDimExpr(currSrcDim++));
        }
      }
    }
    currDstDim++;
  }

  // For compatible shapes, both cursors must reach the end; guaranteed by
  // the op verifier for static shapes.
  assert(currSrcDim == srcShape.size() && currDstDim == dstShape.size());
  return reassociationMap;
}
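// Worked example (illustrative values): collapsing tensor<2x3x4xf32> into
// tensor<6x4xf32> groups source dims 0 and 1 into destination dim 0 and
// maps source dim 2 to destination dim 1, i.e. the reassociation map is
// [[d0, d1], [d2]].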
static Value createCollapse(OpBuilder &builder, Location loc,
                            TensorType resultType, Value input) {
  auto reassociationMap =
      createReassociationMapForCollapse(builder, input.getType(), resultType);
  return builder.createOrFold<tensor::CollapseShapeOp>(loc, resultType, input,
                                                       reassociationMap);
}
static Value createExpand(OpBuilder &builder, Location loc,
                          TensorType resultType, Value input) {
  auto reassociationMap =
      createReassociationMapForCollapse(builder, resultType, input.getType());
  return builder.createOrFold<tensor::ExpandShapeOp>(loc, resultType, input,
                                                     reassociationMap);
}
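// Illustrative IR for the pair above (syntax abbreviated): collapsing
// tensor<2x3x4xf32> into tensor<6x4xf32> emits roughly
//   %0 = tensor.collapse_shape %arg [[0, 1], [2]]
//       : tensor<2x3x4xf32> into tensor<6x4xf32>
// and createExpand emits the inverse tensor.expand_shape, reusing the same
// map computed from the swapped source and destination types.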
class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
public:
  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    auto loc = reshape.getLoc();
    auto resultType = cast_if_present<ShapedType>(
        getTypeConverter()->convertType(reshape.getType()));
    if (!resultType)
      return rewriter.notifyMatchFailure(reshape.getLoc(),
                                         "could not convert result type");
    auto input = dyn_cast<TypedValue<TensorType>>(adaptor.getInput1());
    if (!input)
      return rewriter.notifyMatchFailure(reshape.getLoc(),
                                         "expected input type to be tensor");

    llvm::SmallVector<int64_t> newShape;
    if (!tosa::getConstShapeValues(reshape.getShape().getDefiningOp(),
                                   newShape))
      return failure();

    // Infer all intermediate types.
    auto inputType = inferReshapeInputType(input, newShape);
    auto expandedType = inferReshapeExpandedType(inputType, newShape);
    auto collapsedType = inferReshapeCollapsedType(inputType, expandedType);

    // Cast the input if needed, then emit the collapse-expand pair.
    auto castInput =
        rewriter.createOrFold<tensor::CastOp>(loc, inputType, input);
    auto collapsed = createCollapse(rewriter, loc, collapsedType, castInput);
    auto expanded = createExpand(rewriter, loc, expandedType, collapsed);

    // Cast to the final result type if needed.
    auto result =
        rewriter.createOrFold<tensor::CastOp>(loc, resultType, expanded);
    rewriter.replaceOp(reshape, result);
    return success();
  }
};
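// Sketch of the emitted IR (illustrative; exact assembly varies by MLIR
// version): reshaping tensor<2x3xf32> to tensor<3x2xf32> goes through the
// fully collapsed intermediate tensor<6xf32>:
//   %0 = tensor.collapse_shape %in [[0, 1]]
//       : tensor<2x3xf32> into tensor<6xf32>
//   %1 = tensor.expand_shape %0 [[0, 1]]
//       : tensor<6xf32> into tensor<3x2xf32>
// The tensor.cast ops fold away here because no retyping is needed.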
class SliceConverter : public OpConversionPattern<tosa::SliceOp> {
public:
  using OpConversionPattern<tosa::SliceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::SliceOp sliceOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = sliceOp.getLoc();
    Value input = adaptor.getInput1();
    ShapedType resultType = cast<ShapedType>(sliceOp.getType());
    if (llvm::isa<UnrankedTensorType>(resultType))
      return failure();

    ElementsAttr startElems;
    ElementsAttr sizeElems;
    if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
      return rewriter.notifyMatchFailure(
          sliceOp, "start of slice must be a static ranked shape");
    if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
      return rewriter.notifyMatchFailure(
          sliceOp, "size of slice must be a static ranked shape");

    llvm::SmallVector<int64_t> sliceStarts =
        llvm::to_vector(startElems.getValues<int64_t>());
    llvm::SmallVector<int64_t> sliceSizes =
        llvm::to_vector(sizeElems.getValues<int64_t>());

    SmallVector<int64_t> strides, sizes;
    strides.resize(cast<ShapedType>(sliceOp.getType()).getRank(), 1);

    // A size of -1 marks a dynamic extent, computed as dim - start.
    SmallVector<Value> dynSizes;
    for (const auto &i : llvm::enumerate(sliceSizes)) {
      int64_t size = i.value();
      size_t index = i.index();
      sizes.push_back(size == -1 ? ShapedType::kDynamic : size);
      if (!ShapedType::isDynamic(sizes.back()))
        continue;
      auto dim = rewriter.create<tensor::DimOp>(loc, input, index);
      auto offset = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIndexAttr(sliceStarts[index]));
      dynSizes.push_back(rewriter.create<arith::SubIOp>(loc, dim, offset));
    }

    auto newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
        sliceOp.getLoc(), sliceOp.getType(), input, ValueRange({}), dynSizes,
        ValueRange({}), rewriter.getDenseI64ArrayAttr(sliceStarts),
        rewriter.getDenseI64ArrayAttr(sizes),
        rewriter.getDenseI64ArrayAttr(strides));

    rewriter.replaceOp(sliceOp, newSliceOp.getResult());

    // Erase the constant shape operands if this slice was their only use.
    Operation *startConstShape = sliceOp.getStart().getDefiningOp();
    if (startConstShape->getResult(0).hasOneUse())
      rewriter.eraseOp(startConstShape);
    Operation *sizeConstShape = sliceOp.getSize().getDefiningOp();
    if (sizeConstShape->getResult(0).hasOneUse())
      rewriter.eraseOp(sizeConstShape);

    return success();
  }
};
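// Sketch of the emitted IR (illustrative): a tosa.slice with start = [0, 1]
// and size = [2, 2] on tensor<3x4xf32> becomes roughly
//   %0 = tensor.extract_slice %in[0, 1] [2, 2] [1, 1]
//       : tensor<3x4xf32> to tensor<2x2xf32>
// A size entry of -1 instead yields a dynamic extent computed above as
// tensor.dim minus the start offset.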
class PadConverter : public OpConversionPattern<tosa::PadOp> {
public:
  using OpConversionPattern<tosa::PadOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::PadOp padOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    auto loc = padOp.getLoc();
    auto input = padOp.getInput1();

    ElementsAttr paddingElems;
    if (!matchPattern(padOp.getPadding(), m_Constant(&paddingElems)))
      return rewriter.notifyMatchFailure(
          padOp, "padding must be a static shape value");

    llvm::SmallVector<int64_t> paddingVals;
    for (auto idx : paddingElems.getValues<IntegerAttr>())
      paddingVals.push_back(static_cast<int64_t>(idx.getInt()));

    ShapedType inputTy = cast<ShapedType>(input.getType());
    int64_t rank = inputTy.getRank();

    // Extract the scalar pad value from the pad_const operand.
    Value padConstant = rewriter.createOrFold<tensor::ExtractOp>(
        loc, padOp.getPadConst(), ValueRange({}));
    if (!padConstant)
      return rewriter.notifyMatchFailure(
          padOp, "tosa.pad was unable to determine the pad constant value.");

    // Padding values are interleaved as [low0, high0, low1, high1, ...].
    SmallVector<Value> lowValues;
    SmallVector<Value> highValues;
    lowValues.reserve(rank);
    highValues.reserve(rank);
    for (int i = 0; i < rank; i++) {
      Value lowVal = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIndexAttr(paddingVals[2 * i]));
      Value highVal = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIndexAttr(paddingVals[2 * i + 1]));
      lowValues.push_back(lowVal);
      highValues.push_back(highVal);
    }

    auto newPadOp = rewriter.create<tensor::PadOp>(
        loc, padOp.getType(), input, lowValues, highValues, padConstant);

    rewriter.replaceOp(padOp, newPadOp.getResult());
    return success();
  }
};
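// Sketch of the emitted IR (illustrative): padding tensor<2x3xf32> with
// padding = [1, 1, 0, 2] (interleaved low/high per dim) becomes roughly
//   %0 = tensor.pad %in low[1, 0] high[1, 2] {
//     ^bb0(%i: index, %j: index):
//       tensor.yield %pad : f32
//   } : tensor<2x3xf32> to tensor<4x5xf32>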
struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
  using OpConversionPattern<tosa::ConcatOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::ConcatOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto resultType = dyn_cast<RankedTensorType>(op.getType());

    Location loc = op.getLoc();
    int axis = op.getAxis();
    Value axisValue =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(axis));
    int64_t rank = resultType.getRank();

    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes =
        tensor::getMixedSizes(rewriter, op.getLoc(), adaptor.getOperands()[0]);

    // Pre-compute the offsets along the axis: entry i is where operand i
    // starts; the last entry is the total size along 'axis'.
    SmallVector<OpFoldResult> axisOffsets;
    axisOffsets.push_back(rewriter.getIndexAttr(0));
    axisOffsets.push_back(sizes[axis]);

    for (auto arg : adaptor.getOperands().drop_front()) {
      auto size = rewriter.createOrFold<tensor::DimOp>(loc, arg, axisValue);
      auto currentOffset =
          getValueOrCreateConstantIndexOp(rewriter, loc, axisOffsets.back());
      auto total =
          rewriter.createOrFold<arith::AddIOp>(loc, currentOffset, size);
      axisOffsets.push_back(getAsOpFoldResult(total));
    }
    sizes[axis] = axisOffsets.back();

    // Collect dynamic dimension sizes for the destination tensor.empty.
    SmallVector<Value> dynDims;
    for (int64_t i = 0; i < rank; ++i) {
      if (resultType.isDynamicDim(i))
        dynDims.push_back(
            getValueOrCreateConstantIndexOp(rewriter, loc, sizes[i]));
    }

    Value result = rewriter.create<tensor::EmptyOp>(
        loc, resultType.getShape(), resultType.getElementType(), dynDims);

    // Insert each operand into the destination at its axis offset.
    for (auto [arg, offset] : llvm::zip(adaptor.getOperands(), axisOffsets)) {
      auto sizes = tensor::getMixedSizes(rewriter, op.getLoc(), arg);
      offsets[axis] = offset;
      result = rewriter.createOrFold<tensor::InsertSliceOp>(
          loc, arg, result, offsets, sizes, strides);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

void mlir::tosa::populateTosaToTensorConversionPatterns(
    const TypeConverter &converter, RewritePatternSet *patterns) {
  patterns
      ->add<ConcatConverter, PadConverter, ReshapeConverter, SliceConverter>(
          converter, patterns->getContext());
}
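// Illustrative end state: concatenating two tensor<2x3xf32> values along
// axis 0 materializes a tensor.empty of type tensor<4x3xf32> and inserts
// each operand with tensor.insert_slice at axis offsets 0 and 2.
//
// Typical use of the entry point (a sketch; pass boilerplate and the
// ConversionTarget setup are omitted and hypothetical):
//   RewritePatternSet patterns(ctx);
//   mlir::tosa::populateTosaToTensorConversionPatterns(converter, &patterns);
//   (void)applyPartialConversion(op, target, std::move(patterns));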