#include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;
using namespace tosa;

namespace {
// Infer the type of the cast that must be applied to the input of a lowered
// 'tosa.reshape' op.
TensorType inferReshapeInputType(TypedValue<TensorType> input,
                                 ArrayRef<int64_t> newShape) {
  // No cast is needed for a non-empty target shape.
  if (!newShape.empty())
    return input.getType();

  // For a 0D result, cast the input to a tensor of the same rank with all
  // unit dimensions. This avoids collapsing a dynamically shaped tensor
  // directly into a 0D tensor, which bufferization cannot handle well.
  SmallVector<int64_t> shape(input.getType().getRank(), 1);
  return input.getType().clone(shape);
}
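// For example: reshaping a tensor<?x?xf32> to a 0D result first casts the
// input to tensor<1x1xf32>, so the collapse below starts from a fully
// static type.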
// Infer the result type of the 'tensor.expand_shape' op in the
// collapse-expand pair emitted for a 'tosa.reshape' op.
TensorType inferReshapeExpandedType(TensorType inputType,
                                    ArrayRef<int64_t> newShape) {
  // Special case for a 0D output tensor.
  if (newShape.empty())
    return inputType.clone(ArrayRef<int64_t>{});

  // If the input is static, compute its total size.
  bool inputIsStatic = inputType.hasStaticShape();
  int64_t totalSize = inputIsStatic ? inputType.getNumElements() : -1;

  // Compute the result shape, resolving the -1 placeholder if present.
  auto resultShape =
      llvm::map_to_vector(newShape, [&](int64_t size) -> int64_t {
        // Keep anything that is not the placeholder as-is.
        if (size >= 0)
          return size;

        // Without a known total size, the placeholder stays dynamic.
        if (!inputIsStatic)
          return ShapedType::kDynamic;

        // Product of all components except the -1 placeholder; the leading
        // minus cancels the factor contributed by the placeholder itself.
        int64_t totalSizeNoPlaceholder = -llvm::product_of(newShape);

        // A zero component in 'newShape' resolves the placeholder to 0.
        if (totalSizeNoPlaceholder == 0)
          return 0;

        // Otherwise the placeholder is the total size divided by the
        // product of all other components.
        return totalSize / totalSizeNoPlaceholder;
      });

  bool resultIsStatic = ShapedType::isStaticShape(resultShape);

  // 'tensor.expand_shape' forbids expanding a dynamically shaped input into
  // a statically shaped result; making the first result dimension dynamic
  // satisfies this restriction.
  if (!inputIsStatic && resultIsStatic)
    resultShape[0] = ShapedType::kDynamic;

  // The converse (static input, dynamic result) cannot occur by
  // construction of the placeholder inference above.
  assert(!inputIsStatic || resultIsStatic);

  return inputType.clone(resultShape);
}
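// For example: with inputType = tensor<6x4xf32> (24 elements) and
// newShape = {3, -1}, llvm::product_of(newShape) is -3, so
// totalSizeNoPlaceholder = 3 and the placeholder resolves to 24 / 3 = 8,
// giving tensor<3x8xf32>.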
// Infer the result type of the 'tensor.collapse_shape' op in the
// collapse-expand pair emitted for a 'tosa.reshape' op.
TensorType inferReshapeCollapsedType(TensorType lhsType, TensorType rhsType) {
  auto lhsShape = lhsType.getShape();
  auto rhsShape = rhsType.getShape();

  if (lhsShape.empty() || rhsShape.empty())
    return lhsType.clone(ArrayRef<int64_t>{});

  if (ShapedType::isDynamicShape(lhsShape) ||
      ShapedType::isDynamicShape(rhsShape))
    return lhsType.clone({ShapedType::kDynamic});

  // Accumulate dimensions on both sides until their products match, and
  // emit each matched product as one intermediate dimension.
  SmallVector<int64_t> intermediateShape;
  unsigned currLhsDim = 0, currRhsDim = 0;
  while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) {
    int64_t rhsSize = rhsShape[currRhsDim];
    int64_t lhsSize = lhsShape[currLhsDim];
    while (lhsSize != rhsSize && currLhsDim < lhsShape.size() &&
           currRhsDim < rhsShape.size()) {
      if (lhsSize < rhsSize) {
        currLhsDim++;
        if (currLhsDim < lhsShape.size())
          lhsSize *= lhsShape[currLhsDim];
      } else {
        currRhsDim++;
        if (currRhsDim < rhsShape.size())
          rhsSize *= rhsShape[currRhsDim];
      }
    }
    if (lhsSize == rhsSize)
      intermediateShape.push_back(lhsSize);
    currLhsDim++;
    currRhsDim++;
  }

  // Any dimensions left over on either side must be unit dimensions.
  for (; currLhsDim < lhsShape.size(); currLhsDim++)
    assert(lhsShape[currLhsDim] == 1);
  for (; currRhsDim < rhsShape.size(); currRhsDim++)
    assert(rhsShape[currRhsDim] == 1);

  return lhsType.clone(intermediateShape);
}
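// For example: lhsType = tensor<2x3x4xf32> and rhsType = tensor<6x4xf32>
// yield the intermediate shape [6, 4]: the leading 2x3 of the lhs
// accumulates to 6, matching the first rhs dimension, and the trailing
// dimensions of size 4 match directly.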
SmallVector<ReassociationExprs>
createReassociationMapForCollapse(OpBuilder &builder, Type srcType,
                                  Type dstType) {
  auto srcShape = cast<TensorType>(srcType).getShape();
  auto dstShape = cast<TensorType>(dstType).getShape();

  if (srcShape.empty() || dstShape.empty())
    return {};

  // If either shape is dynamic, collapse everything into the single
  // destination dimension.
  if (ShapedType::isDynamicShape(srcShape) ||
      ShapedType::isDynamicShape(dstShape)) {
    assert(dstShape.size() == 1);
    SmallVector<AffineExpr, 2> exprs;
    for (auto i : llvm::seq<int64_t>(srcShape.size()))
      exprs.push_back(builder.getAffineDimExpr(i));
    return {exprs};
  }

  SmallVector<ReassociationExprs> reassociationMap(dstShape.size());
  unsigned currSrcDim = 0, currDstDim = 0;
  while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
    int64_t dstSize = dstShape[currDstDim];
    int64_t srcSize = srcShape[currSrcDim];
    while (srcSize < dstSize && currSrcDim < srcShape.size()) {
      reassociationMap[currDstDim].push_back(
          builder.getAffineDimExpr(currSrcDim++));
      srcSize *= srcShape[currSrcDim];
    }
    if (srcSize == dstSize) {
      reassociationMap[currDstDim].push_back(
          builder.getAffineDimExpr(currSrcDim++));
      // If the next destination dimension is not 1, fold any following unit
      // source dimensions into the current group.
      if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
        while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1)
          reassociationMap[currDstDim].push_back(
              builder.getAffineDimExpr(currSrcDim++));
      }
    }
    currDstDim++;
  }

  // For compatible shapes, both iterators must reach the end.
  assert(currSrcDim == srcShape.size() && currDstDim == dstShape.size());
  return reassociationMap;
}
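// For example: collapsing tensor<2x3x4xf32> into tensor<6x4xf32> produces
// the reassociation map [[d0, d1], [d2]]: source dimensions 0 and 1 group
// into destination dimension 0, and source dimension 2 maps to destination
// dimension 1.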
// Create a tensor.collapse_shape op that reshapes the input into the given
// result type.
Value createCollapse(OpBuilder &builder, Location loc, TensorType resultType,
                     Value input) {
  auto reassociationMap =
      createReassociationMapForCollapse(builder, input.getType(), resultType);
  return builder.createOrFold<tensor::CollapseShapeOp>(loc, resultType, input,
                                                       reassociationMap);
}

// Create a tensor.expand_shape op that reshapes the input into the given
// result type. The reassociation map is the same as for the reverse
// collapse, so it is computed with the arguments swapped.
Value createExpand(OpBuilder &builder, Location loc, TensorType resultType,
                   Value input) {
  auto reassociationMap =
      createReassociationMapForCollapse(builder, resultType, input.getType());
  return builder.createOrFold<tensor::ExpandShapeOp>(loc, resultType, input,
                                                     reassociationMap);
}
class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
public:
  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    auto loc = reshape.getLoc();
    auto resultType =
        getTypeConverter()->convertType<ShapedType>(reshape.getType());
    if (!resultType) {
      return rewriter.notifyMatchFailure(reshape.getLoc(),
                                         "could not convert result type");
    }
    auto input = dyn_cast<TypedValue<TensorType>>(adaptor.getInput1());
    if (!input) {
      return rewriter.notifyMatchFailure(reshape.getLoc(),
                                         "expected input type to be tensor");
    }

    llvm::SmallVector<int64_t> newShape;
    if (!tosa::getConstShapeValues(reshape.getShape().getDefiningOp(),
                                   newShape))
      return failure();

    // Infer all intermediate types.
    auto inputType = inferReshapeInputType(input, newShape);
    auto expandedType = inferReshapeExpandedType(inputType, newShape);
    auto collapsedType = inferReshapeCollapsedType(inputType, expandedType);

    // Cast the input if needed.
    auto castInput =
        rewriter.createOrFold<tensor::CastOp>(loc, inputType, input);

    // Emit the collapse-expand pair.
    auto collapsed = createCollapse(rewriter, loc, collapsedType, castInput);
    auto expanded = createExpand(rewriter, loc, expandedType, collapsed);

    // Cast to the final result type if needed.
    auto result =
        rewriter.createOrFold<tensor::CastOp>(loc, resultType, expanded);
    rewriter.replaceOp(reshape, result);
    return success();
  }
};
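// A rough sketch of the resulting IR for a rank-preserving reshape
// (illustrative only, not actual printer output):
//
//   %r = tosa.reshape %arg0, %shape : (tensor<2x3xf32>, ...)
//        -> tensor<3x2xf32>
//
// becomes a collapse to the common intermediate shape followed by an expand:
//
//   %c = tensor.collapse_shape %arg0 [[0, 1]]
//        : tensor<2x3xf32> into tensor<6xf32>
//   %e = tensor.expand_shape %c [[0, 1]] output_shape [3, 2]
//        : tensor<6xf32> into tensor<3x2xf32>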
class SliceConverter : public OpConversionPattern<tosa::SliceOp> {
public:
  using OpConversionPattern<tosa::SliceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::SliceOp sliceOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = sliceOp.getLoc();
    Value input = adaptor.getInput1();
    ShapedType resultType = cast<ShapedType>(sliceOp.getType());
    if (llvm::isa<UnrankedTensorType>(resultType))
      return failure();

    ElementsAttr startElems;
    ElementsAttr sizeElems;

    if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
      return rewriter.notifyMatchFailure(
          sliceOp, "start of slice must be a static ranked shape");

    if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
      return rewriter.notifyMatchFailure(
          sliceOp, "size of slice must be a static ranked shape");

    llvm::SmallVector<int64_t> sliceStarts =
        llvm::to_vector(startElems.getValues<int64_t>());
    llvm::SmallVector<int64_t> sliceSizes =
        llvm::to_vector(sizeElems.getValues<int64_t>());

    SmallVector<int64_t> strides, sizes;
    strides.resize(cast<ShapedType>(sliceOp.getType()).getRank(), 1);

    SmallVector<Value> dynSizes;
    for (const auto &i : llvm::enumerate(sliceSizes)) {
      int64_t size = i.value();
      size_t index = i.index();
      sizes.push_back(size == -1 ? ShapedType::kDynamic : size);
      if (ShapedType::isStatic(sizes.back()))
        continue;

      // A size of -1 means "to the end of the dimension": dim(input) - start.
      auto dim = tensor::DimOp::create(rewriter, loc, input, index);
      auto offset = arith::ConstantOp::create(
          rewriter, loc, rewriter.getIndexAttr(sliceStarts[index]));
      dynSizes.push_back(arith::SubIOp::create(rewriter, loc, dim, offset));
    }

    auto newSliceOp = tensor::ExtractSliceOp::create(
        rewriter, sliceOp.getLoc(), sliceOp.getType(), input, ValueRange({}),
        dynSizes, ValueRange({}), rewriter.getDenseI64ArrayAttr(sliceStarts),
        rewriter.getDenseI64ArrayAttr(sizes),
        rewriter.getDenseI64ArrayAttr(strides));

    // Erase the start/size shape operands if this slice was their only use.
    Operation *startConstShape = sliceOp.getStart().getDefiningOp();
    if (startConstShape->getResult(0).hasOneUse())
      rewriter.eraseOp(startConstShape);

    Operation *sizeConstShape = sliceOp.getSize().getDefiningOp();
    if (sizeConstShape->getResult(0).hasOneUse())
      rewriter.eraseOp(sizeConstShape);

    rewriter.replaceOp(sliceOp, newSliceOp.getResult());
    return success();
  }
};
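// A rough sketch of the lowering (illustrative only): with constant
// start = [1, 2] and size = [2, 3],
//
//   %r = tosa.slice %arg0, %start, %size : ... -> tensor<2x3xf32>
//
// becomes
//
//   %r = tensor.extract_slice %arg0[1, 2] [2, 3] [1, 1]
//        : tensor<4x8xf32> to tensor<2x3xf32>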
class PadConverter : public OpConversionPattern<tosa::PadOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::PadOp padOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    auto loc = padOp.getLoc();
    auto input = padOp.getInput1();

    ElementsAttr paddingElems;
    if (!matchPattern(padOp.getPadding(), m_Constant(&paddingElems))) {
      return rewriter.notifyMatchFailure(
          padOp, "padding must be a static shape value");
    }
    llvm::SmallVector<int64_t> paddingVals;
    for (auto idx : paddingElems.getValues<IntegerAttr>()) {
      paddingVals.push_back(static_cast<int64_t>(idx.getInt()));
    }

    ShapedType inputTy = cast<ShapedType>(input.getType());
    int64_t rank = inputTy.getRank();

    // Extract the scalar pad constant from the pad_const operand.
    Value padConstant = rewriter.createOrFold<tensor::ExtractOp>(
        loc, padOp.getPadConst(),
        ValueRange({arith::ConstantIndexOp::create(rewriter, loc, 0)}));

    if (!padConstant) {
      return rewriter.notifyMatchFailure(
          padOp, "tosa.pad was unable to determine the pad constant value.");
    }

    SmallVector<OpFoldResult, 3> lowValues;
    SmallVector<OpFoldResult, 3> highValues;

    lowValues.reserve(rank);
    highValues.reserve(rank);

    // Padding values are interleaved as [low0, high0, low1, high1, ...].
    for (int i = 0; i < rank; i++) {
      Value lowVal = arith::ConstantOp::create(
          rewriter, loc, rewriter.getIndexAttr(paddingVals[2 * i]));
      Value highVal = arith::ConstantOp::create(
          rewriter, loc, rewriter.getIndexAttr(paddingVals[2 * i + 1]));
      lowValues.push_back(lowVal);
      highValues.push_back(highVal);
    }

    auto newPadOp = tensor::PadOp::create(rewriter, loc, padOp.getType(), input,
                                          lowValues, highValues, padConstant);

    rewriter.replaceOp(padOp, newPadOp.getResult());
    return success();
  }
};
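// A rough sketch of the lowering (illustrative only): a rank-2 pad with
// padding = [1, 1, 2, 2] and pad constant %cst becomes
//
//   %r = tensor.pad %arg0 low[1, 2] high[1, 2] {
//   ^bb0(%i: index, %j: index):
//     tensor.yield %cst : f32
//   } : tensor<4x4xf32> to tensor<6x8xf32>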
struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
  using OpConversionPattern<tosa::ConcatOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::ConcatOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto resultType = dyn_cast<RankedTensorType>(op.getType());

    Location loc = op.getLoc();
    int axis = op.getAxis();
    Value axisValue =
        arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(axis));
    int64_t rank = resultType.getRank();

    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes =
        tensor::getMixedSizes(rewriter, op.getLoc(), adaptor.getOperands()[0]);

    // Pre-compute the offsets along the concatenation axis: one entry per
    // operand, plus a final entry holding the total size along the axis.
    SmallVector<OpFoldResult> axisOffsets;
    axisOffsets.push_back(rewriter.getIndexAttr(0));
    axisOffsets.push_back(sizes[axis]);

    for (auto arg : adaptor.getOperands().drop_front()) {
      auto size = rewriter.createOrFold<tensor::DimOp>(loc, arg, axisValue);
      auto currentOffset =
          getValueOrCreateConstantIndexOp(rewriter, loc, axisOffsets.back());
      auto total =
          rewriter.createOrFold<arith::AddIOp>(loc, currentOffset, size);
      axisOffsets.push_back(getAsOpFoldResult(total));
    }
    sizes[axis] = axisOffsets.back();

    // Collect the dynamic sizes needed by the tensor.empty op, based on the
    // result type of the tosa.concat op, which the conversion must preserve.
    SmallVector<Value> dynDims;
    for (int64_t i = 0; i < rank; ++i) {
      if (resultType.isDynamicDim(i)) {
        dynDims.push_back(
            getValueOrCreateConstantIndexOp(rewriter, loc, sizes[i]));
      }
    }

    Value result =
        tensor::EmptyOp::create(rewriter, loc, resultType.getShape(),
                                resultType.getElementType(), dynDims);

    // Insert each operand into the result at its offset along the axis.
    for (auto [arg, offset] : llvm::zip(adaptor.getOperands(), axisOffsets)) {
      auto sizes = tensor::getMixedSizes(rewriter, op.getLoc(), arg);
      offsets[axis] = offset;
      result = rewriter.createOrFold<tensor::InsertSliceOp>(
          loc, arg, result, offsets, sizes, strides);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};
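// A rough sketch of the lowering (illustrative only): concatenating two
// tensor<2x3xf32> values along axis 0 becomes
//
//   %empty = tensor.empty() : tensor<4x3xf32>
//   %0 = tensor.insert_slice %arg0 into %empty[0, 0] [2, 3] [1, 1] ...
//   %1 = tensor.insert_slice %arg1 into %0[2, 0] [2, 3] [1, 1] ...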
} // namespace

void mlir::tosa::populateTosaToTensorConversionPatterns(
    const TypeConverter &converter, RewritePatternSet *patterns) {
  patterns
      ->add<ConcatConverter, PadConverter, ReshapeConverter, SliceConverter>(
          converter, patterns->getContext());
}
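// Illustrative only: a minimal sketch of how a conversion pass might drive
// these patterns. The names below ('target', 'getOperation()') are
// assumptions for the example, not part of this file.
//
//   RewritePatternSet patterns(&getContext());
//   TypeConverter converter;
//   converter.addConversion([](Type type) { return type; });
//   mlir::tosa::populateTosaToTensorConversionPatterns(converter, &patterns);
//   if (failed(applyPartialConversion(getOperation(), target,
//                                     std::move(patterns))))
//     signalPassFailure();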