34 if (!newShape.empty())
35 return input.getType();
43 return input.getType().clone(shape);
56 bool inputIsStatic = inputType.hasStaticShape();
57 int64_t totalSize = inputIsStatic ? inputType.getNumElements() : -1;
60 bool resultIsStatic =
true;
61 auto resultShape = llvm::map_to_vector(newShape, [&](int64_t size) -> int64_t {
69 resultIsStatic =
false;
70 return ShapedType::kDynamic;
75 int64_t totalSizeNoPlaceholder = -std::accumulate(
76 newShape.begin(), newShape.end(), 1, std::multiplies<int64_t>());
79 if (totalSizeNoPlaceholder == 0)
84 return totalSize / totalSizeNoPlaceholder;
90 if (!inputIsStatic && resultIsStatic)
91 resultShape[0] = ShapedType::kDynamic;
96 assert(!inputIsStatic || resultIsStatic);
99 return inputType.
clone(resultShape);
108 if (lhsShape.empty() || rhsShape.empty())
111 if (ShapedType::isDynamicShape(lhsShape) || ShapedType::isDynamicShape(rhsShape))
112 return lhsType.
clone({ShapedType::kDynamic});
115 unsigned currLhsDim = 0, currRhsDim = 0;
116 while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) {
117 int64_t rhsSize = rhsShape[currRhsDim];
118 int64_t lhsSize = lhsShape[currLhsDim];
119 while (lhsSize != rhsSize && currLhsDim < lhsShape.size() &&
120 currRhsDim < rhsShape.size()) {
121 if (lhsSize < rhsSize) {
123 if (currLhsDim < lhsShape.size()) {
124 lhsSize *= lhsShape[currLhsDim];
128 if (currRhsDim < rhsShape.size()) {
129 rhsSize *= rhsShape[currRhsDim];
133 if (lhsSize == rhsSize) {
134 intermediateShape.push_back(lhsSize);
142 for (; currLhsDim < lhsShape.size(); currLhsDim++) {
143 assert(lhsShape[currLhsDim] == 1);
145 for (; currRhsDim < rhsShape.size(); currRhsDim++) {
146 assert(rhsShape[currRhsDim] == 1);
149 return lhsType.
clone(intermediateShape);
153 createReassociationMapForCollapse(
OpBuilder &builder,
Type srcType,
Type dstType) {
154 auto srcShape = cast<TensorType>(srcType).getShape();
155 auto dstShape = cast<TensorType>(dstType).getShape();
157 if (srcShape.empty() || dstShape.empty())
160 if (ShapedType::isDynamicShape(srcShape) || ShapedType::isDynamicShape(dstShape)) {
161 assert(dstShape.size() == 1);
163 for (
auto i : llvm::seq<int64_t>(srcShape.size()))
169 unsigned currSrcDim = 0, currDstDim = 0;
170 while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
171 int64_t dstSize = dstShape[currDstDim];
172 int64_t srcSize = srcShape[currSrcDim];
173 while (srcSize < dstSize && currSrcDim < srcShape.size()) {
174 reassociationMap[currDstDim].push_back(
176 srcSize *= srcShape[currSrcDim];
178 if (srcSize == dstSize) {
179 reassociationMap[currDstDim].push_back(
183 if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
184 while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
185 reassociationMap[currDstDim].push_back(
196 assert(currSrcDim == srcShape.size() && currDstDim == dstShape.size());
197 return reassociationMap;
204 auto reassociationMap =
205 createReassociationMapForCollapse(builder, input.
getType(), resultType);
206 return builder.
createOrFold<tensor::CollapseShapeOp>(loc, resultType, input,
214 auto reassociationMap =
215 createReassociationMapForCollapse(builder, resultType, input.
getType());
216 return builder.
createOrFold<tensor::ExpandShapeOp>(loc, resultType, input,
225 matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
227 auto loc = reshape.getLoc();
228 auto resultType = reshape.getResult().getType();
229 auto input = reshape.getInput1();
230 auto newShape = reshape.getNewShape();
233 auto inputType = inferReshapeInputType(input, newShape);
234 auto expandedType = inferReshapeExpandedType(inputType, newShape);
235 auto collapsedType = inferReshapeCollapsedType(inputType, expandedType);
238 auto castInput = rewriter.createOrFold<tensor::CastOp>(loc, inputType, input);
241 auto collapsed = createCollapse(rewriter, loc, collapsedType, castInput);
242 auto expanded = createExpand(rewriter, loc, expandedType, collapsed);
245 auto result = rewriter.createOrFold<tensor::CastOp>(loc, resultType, expanded);
246 rewriter.replaceOp(reshape, result);
256 matchAndRewrite(tosa::SliceOp sliceOp, OpAdaptor adaptor,
259 Value input = adaptor.getInput();
260 ShapedType resultType = cast<ShapedType>(sliceOp.getType());
261 if (llvm::isa<UnrankedTensorType>(resultType))
265 strides.resize(cast<ShapedType>(sliceOp.getType()).getRank(), 1);
269 int64_t size = i.value();
270 size_t index = i.index();
271 sizes.push_back(size == -1 ? ShapedType::kDynamic : size);
272 if (!ShapedType::isDynamic(sizes.back()))
275 auto dim = rewriter.create<tensor::DimOp>(loc, input, index);
276 auto offset = rewriter.create<arith::ConstantOp>(
277 loc, rewriter.getIndexAttr(starts[index]));
278 dynSizes.push_back(rewriter.create<arith::SubIOp>(loc, dim, offset));
281 auto newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
282 sliceOp.getLoc(), sliceOp.getType(), input,
ValueRange({}), dynSizes,
283 ValueRange({}), rewriter.getDenseI64ArrayAttr(starts),
284 rewriter.getDenseI64ArrayAttr(sizes),
285 rewriter.getDenseI64ArrayAttr(strides));
287 rewriter.replaceOp(sliceOp, newSliceOp.getResult());
298 auto loc = padOp.getLoc();
299 auto input = padOp.getInput1();
300 auto padding = padOp.getPadding();
302 ShapedType inputTy = cast<ShapedType>(input.getType());
303 Type elementTy = inputTy.getElementType();
304 int64_t rank = inputTy.getRank();
310 if (padOp.getPadConst()) {
311 padConstant = rewriter.createOrFold<tensor::ExtractOp>(
314 TypedAttr constantAttr;
315 if (isa<FloatType>(elementTy)) {
316 constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
317 }
else if (isa<IntegerType>(elementTy) && !padOp.getQuantizationInfo()) {
318 constantAttr = rewriter.getIntegerAttr(elementTy, 0);
319 }
else if (isa<IntegerType>(elementTy) && padOp.getQuantizationInfo()) {
320 int64_t value = padOp.getQuantizationInfo()->getInputZp();
321 constantAttr = rewriter.getIntegerAttr(elementTy, value);
324 padConstant = rewriter.create<arith::ConstantOp>(loc, constantAttr);
328 return rewriter.notifyMatchFailure(
329 padOp,
"tosa.pad was unable to determine the pad constant value.");
333 rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
335 rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
340 lowValues.reserve(rank);
341 highValues.reserve(rank);
343 for (
int i = 0; i < rank; i++) {
344 Value inputIndex = rewriter.create<arith::ConstantIndexOp>(loc, i);
345 Value lowVal = rewriter.createOrFold<tensor::ExtractOp>(
346 loc, padding,
ValueRange({inputIndex, lowIndex}));
347 Value highVal = rewriter.createOrFold<tensor::ExtractOp>(
348 loc, padding,
ValueRange({inputIndex, highIndex}));
350 lowVal = rewriter.createOrFold<arith::IndexCastOp>(
351 loc, rewriter.getIndexType(), lowVal);
352 highVal = rewriter.createOrFold<arith::IndexCastOp>(
353 loc, rewriter.getIndexType(), highVal);
355 lowValues.push_back(lowVal);
356 highValues.push_back(highVal);
359 auto newPadOp = rewriter.create<tensor::PadOp>(
360 loc, padOp.getType(), input, lowValues, highValues, padConstant);
362 rewriter.replaceOp(padOp, newPadOp.getResult());
371 matchAndRewrite(tosa::ConcatOp op, OpAdaptor adaptor,
373 auto resultType = dyn_cast<RankedTensorType>(op.getType());
376 int axis = op.getAxis();
379 int64_t rank = resultType.getRank();
391 axisOffsets.push_back(sizes[axis]);
393 for (
auto arg : adaptor.getOperands().drop_front()) {
394 auto size = rewriter.
createOrFold<tensor::DimOp>(loc, arg, axisValue);
398 rewriter.
createOrFold<arith::AddIOp>(loc, currentOffset, size);
401 sizes[axis] = axisOffsets.back();
408 for (int64_t i = 0; i < rank; ++i) {
409 if (resultType.isDynamicDim(i)) {
416 loc, resultType.getShape(), resultType.getElementType(), dynDims);
418 for (
auto [arg, offset] : llvm::zip(adaptor.getOperands(), axisOffsets)) {
420 offsets[axis] = offset;
422 loc, arg, result, offsets, sizes, strides);
IntegerAttr getIndexAttr(int64_t value)
AffineExpr getAffineDimExpr(unsigned position)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around an Attribute.
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Location getLoc()
The source location the operation was defined or derived from.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any necessary changes.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and UnrankedTensorType.
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
RankedTensorType clone(ArrayRef< int64_t > shape, Type elementType) const
Return a clone of this type with the given new shape and element type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
void populateTosaToTensorConversionPatterns(RewritePatternSet *patterns)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.