MLIR  22.0.0git
TosaToTensor.cpp
Go to the documentation of this file.
1 //===- TosaToTensor.cpp - Lowering Tosa to Tensor Dialect -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // These rewriters lower from the Tosa to the Tensor dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
19 #include "mlir/IR/PatternMatch.h"
21 
22 #include <numeric>
23 
24 using namespace mlir;
25 using namespace tosa;
26 
27 namespace {
28 
29 // Infer the type to which the input of a 'tosa.reshape' op must be cast when
30 // lowered.
31 TensorType inferReshapeInputType(TypedValue<TensorType> input,
32  ArrayRef<int64_t> newShape) {
33  // No need to cast input for non-empty target shape
34  if (!newShape.empty())
35  return input.getType();
36 
37  // The input type must be cast into a tensor with the same rank and all static
38  // dimensions set to 1. This prevents the generation of a
39  // tensor.collapse_shape op that converts a dynamically shaped tensor into a
40  // 0D tensor. While such construct is not incorrect on its own, bufferization
41  // cannot properly handle it at the moment, so we avoid it.
42  SmallVector<int64_t> shape(input.getType().getRank(), 1);
43  return input.getType().clone(shape);
44 }
45 
46 // Infer the result type of 'tensor.expand_shape' in the collapse-expand
47 // pair emitted for a 'tosa.reshape' op.
48 TensorType inferReshapeExpandedType(TensorType inputType,
49  ArrayRef<int64_t> newShape) {
50  // Special case for 0D output tensor. Note: Watch out when using Type::clone()
51  // with just '{}', as it will invoke the incorrect overload.
52  if (newShape.empty())
53  return inputType.clone(ArrayRef<int64_t>{});
54 
55  // Check if the input is static, and if so, get its total size
56  bool inputIsStatic = inputType.hasStaticShape();
57  int64_t totalSize = inputIsStatic ? inputType.getNumElements() : -1;
58 
59  // Compute result shape
60  auto resultShape =
61  llvm::map_to_vector(newShape, [&](int64_t size) -> int64_t {
62  // If this is not a placeholder, do not change it.
63  if (size >= 0)
64  return size;
65 
66  // If we do not know the total size of the tensor, keep this dimension
67  // dynamic in the result shape.
68  if (!inputIsStatic)
69  return ShapedType::kDynamic;
70 
71  // Calculate the product of all elements in 'newShape' except for the -1
72  // placeholder, which we discard by negating the result.
73  int64_t totalSizeNoPlaceholder = -std::accumulate(
74  newShape.begin(), newShape.end(), 1, std::multiplies<int64_t>());
75 
76  // If there is a 0 component in 'newShape', resolve the placeholder as
77  // 0.
78  if (totalSizeNoPlaceholder == 0)
79  return 0;
80 
81  // Resolve the placeholder as the quotient between the total tensor size
82  // and the product of all other sizes.
83  return totalSize / totalSizeNoPlaceholder;
84  });
85 
86  bool resultIsStatic = ShapedType::isStaticShape(resultShape);
87 
88  // A syntactic restriction in 'tensor.expand_shape' forbids a dynamically
89  // shaped input from being reshaped into a statically shaped result. We may
90  // simply turn the first result dimension dynamic to address this.
91  if (!inputIsStatic && resultIsStatic)
92  resultShape[0] = ShapedType::kDynamic;
93 
94  // The 'tensor.expand_shape' op also forbids a statically shaped input from
95  // being reshaped into a dynamically shaped result, but the placeholder
96  // inference algorithm above guarantees that this will never be the case.
97  assert(!inputIsStatic || resultIsStatic);
98 
99  // Create result type
100  return inputType.clone(resultShape);
101 }
102 
103 // Infer the result type of 'tensor.collapse_shape' in the collapse-expand
104 // pair emitted for a 'tosa.reshape' op.
105 TensorType inferReshapeCollapsedType(TensorType lhsType, TensorType rhsType) {
106  auto lhsShape = lhsType.getShape();
107  auto rhsShape = rhsType.getShape();
108 
109  if (lhsShape.empty() || rhsShape.empty())
110  return lhsType.clone(ArrayRef<int64_t>{});
111 
112  if (ShapedType::isDynamicShape(lhsShape) ||
113  ShapedType::isDynamicShape(rhsShape))
114  return lhsType.clone({ShapedType::kDynamic});
115 
116  SmallVector<int64_t> intermediateShape;
117  unsigned currLhsDim = 0, currRhsDim = 0;
118  while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) {
119  int64_t rhsSize = rhsShape[currRhsDim];
120  int64_t lhsSize = lhsShape[currLhsDim];
121  while (lhsSize != rhsSize && currLhsDim < lhsShape.size() &&
122  currRhsDim < rhsShape.size()) {
123  if (lhsSize < rhsSize) {
124  currLhsDim++;
125  if (currLhsDim < lhsShape.size()) {
126  lhsSize *= lhsShape[currLhsDim];
127  }
128  } else {
129  currRhsDim++;
130  if (currRhsDim < rhsShape.size()) {
131  rhsSize *= rhsShape[currRhsDim];
132  }
133  }
134  }
135  if (lhsSize == rhsSize) {
136  intermediateShape.push_back(lhsSize);
137  }
138  currRhsDim++;
139  currLhsDim++;
140  }
141 
142  // Static shapes are guaranteed to be compatible by the op verifier, so all
143  // leftover dimensions should be 1.
144  for (; currLhsDim < lhsShape.size(); currLhsDim++) {
145  assert(lhsShape[currLhsDim] == 1);
146  }
147  for (; currRhsDim < rhsShape.size(); currRhsDim++) {
148  assert(rhsShape[currRhsDim] == 1);
149  }
150 
151  return lhsType.clone(intermediateShape);
152 }
153 
155 createReassociationMapForCollapse(OpBuilder &builder, Type srcType,
156  Type dstType) {
157  auto srcShape = cast<TensorType>(srcType).getShape();
158  auto dstShape = cast<TensorType>(dstType).getShape();
159 
160  if (srcShape.empty() || dstShape.empty())
161  return {};
162 
163  if (ShapedType::isDynamicShape(srcShape) ||
164  ShapedType::isDynamicShape(dstShape)) {
165  assert(dstShape.size() == 1);
167  for (auto i : llvm::seq<int64_t>(srcShape.size()))
168  exprs.push_back(builder.getAffineDimExpr(i));
169  return {exprs};
170  }
171 
172  SmallVector<ReassociationExprs> reassociationMap(dstShape.size());
173  unsigned currSrcDim = 0, currDstDim = 0;
174  while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
175  int64_t dstSize = dstShape[currDstDim];
176  int64_t srcSize = srcShape[currSrcDim];
177  while (srcSize < dstSize && currSrcDim < srcShape.size()) {
178  reassociationMap[currDstDim].push_back(
179  builder.getAffineDimExpr(currSrcDim++));
180  srcSize *= srcShape[currSrcDim];
181  }
182  if (srcSize == dstSize) {
183  reassociationMap[currDstDim].push_back(
184  builder.getAffineDimExpr(currSrcDim++));
185  // If the next dim in collapsedShape is not 1, treat subsequent dims in
186  // expandedShape which are 1 to be collapsed.
187  if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
188  while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
189  reassociationMap[currDstDim].push_back(
190  builder.getAffineDimExpr(currSrcDim++));
191  }
192  }
193  }
194  currDstDim++;
195  }
196 
197  // If the source and target shapes are compatible, both iterators must have
198  // reached the end. This condition is guaranteed by the op verifier for
199  // static shapes.
200  assert(currSrcDim == srcShape.size() && currDstDim == dstShape.size());
201  return reassociationMap;
202 }
203 
204 // Create a tensor.collapse_shape op that reshapes the input into the given
205 // result type.
206 Value createCollapse(OpBuilder &builder, Location loc, TensorType resultType,
207  Value input) {
208  auto reassociationMap =
209  createReassociationMapForCollapse(builder, input.getType(), resultType);
210  return builder.createOrFold<tensor::CollapseShapeOp>(loc, resultType, input,
211  reassociationMap);
212 }
213 
214 // Create a tensor.expand_shape op that reshapes the input into the given result
215 // type.
216 Value createExpand(OpBuilder &builder, Location loc, TensorType resultType,
217  Value input) {
218  auto reassociationMap =
219  createReassociationMapForCollapse(builder, resultType, input.getType());
220  return builder.createOrFold<tensor::ExpandShapeOp>(loc, resultType, input,
221  reassociationMap);
222 }
223 
224 class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
225 public:
227 
228  LogicalResult
229  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
230  ConversionPatternRewriter &rewriter) const final {
231  auto loc = reshape.getLoc();
232  auto resultType =
233  getTypeConverter()->convertType<ShapedType>(reshape.getType());
234  if (!resultType) {
235  return rewriter.notifyMatchFailure(reshape.getLoc(),
236  "could not convert result type");
237  }
238  auto input = dyn_cast<TypedValue<TensorType>>(adaptor.getInput1());
239  if (!input) {
240  return rewriter.notifyMatchFailure(reshape.getLoc(),
241  "expected input type to be tensor");
242  }
243 
245  if (!tosa::getConstShapeValues(reshape.getShape().getDefiningOp(),
246  newShape)) {
247  return failure();
248  }
249 
250  // Infer all intermediate types
251  auto inputType = inferReshapeInputType(input, newShape);
252  auto expandedType = inferReshapeExpandedType(inputType, newShape);
253  auto collapsedType = inferReshapeCollapsedType(inputType, expandedType);
254 
255  // Cast input if needed
256  auto castInput =
257  rewriter.createOrFold<tensor::CastOp>(loc, inputType, input);
258 
259  // Emit collaspe-expand pair
260  auto collapsed = createCollapse(rewriter, loc, collapsedType, castInput);
261  auto expanded = createExpand(rewriter, loc, expandedType, collapsed);
262 
263  // Cast to final result type if needed
264  auto result =
265  rewriter.createOrFold<tensor::CastOp>(loc, resultType, expanded);
266  rewriter.replaceOp(reshape, result);
267  return success();
268  }
269 };
270 
271 class SliceConverter : public OpConversionPattern<tosa::SliceOp> {
272 public:
274 
275  LogicalResult
276  matchAndRewrite(tosa::SliceOp sliceOp, OpAdaptor adaptor,
277  ConversionPatternRewriter &rewriter) const final {
278  Location loc = sliceOp.getLoc();
279  Value input = adaptor.getInput1();
280  ShapedType resultType = cast<ShapedType>(sliceOp.getType());
281  if (llvm::isa<UnrankedTensorType>(resultType))
282  return failure();
283 
284  ElementsAttr startElems;
285  ElementsAttr sizeElems;
286 
287  if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
288  return rewriter.notifyMatchFailure(
289  sliceOp, "start of slice must be a static ranked shape");
290 
291  if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
292  return rewriter.notifyMatchFailure(
293  sliceOp, "size of slice must be a static ranked shape");
294 
295  llvm::SmallVector<int64_t> sliceStarts =
296  llvm::to_vector(startElems.getValues<int64_t>());
297  llvm::SmallVector<int64_t> sliceSizes =
298  llvm::to_vector(sizeElems.getValues<int64_t>());
299 
300  SmallVector<int64_t> strides, sizes;
301  strides.resize(cast<ShapedType>(sliceOp.getType()).getRank(), 1);
302 
303  SmallVector<Value> dynSizes;
304  for (const auto &i : llvm::enumerate(sliceSizes)) {
305  int64_t size = i.value();
306  size_t index = i.index();
307  sizes.push_back(size == -1 ? ShapedType::kDynamic : size);
308  if (ShapedType::isStatic(sizes.back()))
309  continue;
310 
311  auto dim = tensor::DimOp::create(rewriter, loc, input, index);
312  auto offset = arith::ConstantOp::create(
313  rewriter, loc, rewriter.getIndexAttr(sliceStarts[index]));
314  dynSizes.push_back(arith::SubIOp::create(rewriter, loc, dim, offset));
315  }
316 
317  auto newSliceOp = tensor::ExtractSliceOp::create(
318  rewriter, sliceOp.getLoc(), sliceOp.getType(), input, ValueRange({}),
319  dynSizes, ValueRange({}), rewriter.getDenseI64ArrayAttr(sliceStarts),
320  rewriter.getDenseI64ArrayAttr(sizes),
321  rewriter.getDenseI64ArrayAttr(strides));
322 
323  // Remove const_shape ops when it no longer has use point.
324  Operation *startConstShape = sliceOp.getStart().getDefiningOp();
325  if (startConstShape->getResult(0).hasOneUse())
326  rewriter.eraseOp(startConstShape);
327 
328  Operation *sizeConstShape = sliceOp.getSize().getDefiningOp();
329  if (sizeConstShape->getResult(0).hasOneUse())
330  rewriter.eraseOp(sizeConstShape);
331 
332  rewriter.replaceOp(sliceOp, newSliceOp.getResult());
333  return success();
334  }
335 };
336 
337 class PadConverter : public OpConversionPattern<tosa::PadOp> {
338 public:
340 
341  LogicalResult
342  matchAndRewrite(tosa::PadOp padOp, OpAdaptor adaptor,
343  ConversionPatternRewriter &rewriter) const final {
344  auto loc = padOp.getLoc();
345  auto input = padOp.getInput1();
346 
347  ElementsAttr paddingElems;
348  if (!matchPattern(padOp.getPadding(), m_Constant(&paddingElems))) {
349  return rewriter.notifyMatchFailure(
350  padOp, "padding must be a static shape value");
351  }
352  llvm::SmallVector<int64_t> paddingVals;
353  for (auto idx : paddingElems.getValues<IntegerAttr>()) {
354  paddingVals.push_back(static_cast<int64_t>(idx.getInt()));
355  }
356 
357  ShapedType inputTy = cast<ShapedType>(input.getType());
358  int64_t rank = inputTy.getRank();
359 
360  // Setup the default constantAttr.
361 
362  Value padConstant = rewriter.createOrFold<tensor::ExtractOp>(
363  loc, padOp.getPadConst(),
364  ValueRange({arith::ConstantIndexOp::create(rewriter, loc, 0)}));
365 
366  if (!padConstant) {
367  return rewriter.notifyMatchFailure(
368  padOp, "tosa.pad was unable to determine the pad constant value.");
369  }
370 
372  SmallVector<OpFoldResult, 3> highValues;
373 
374  lowValues.reserve(rank);
375  highValues.reserve(rank);
376 
377  for (int i = 0; i < rank; i++) {
378  Value lowVal = arith::ConstantOp::create(
379  rewriter, loc, rewriter.getIndexAttr(paddingVals[2 * i]));
380  Value highVal = arith::ConstantOp::create(
381  rewriter, loc, rewriter.getIndexAttr(paddingVals[2 * i + 1]));
382  lowValues.push_back(lowVal);
383  highValues.push_back(highVal);
384  }
385 
386  auto newPadOp = tensor::PadOp::create(rewriter, loc, padOp.getType(), input,
387  lowValues, highValues, padConstant);
388 
389  rewriter.replaceOp(padOp, newPadOp.getResult());
390  return success();
391  }
392 };
393 
394 struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
396 
397  LogicalResult
398  matchAndRewrite(tosa::ConcatOp op, OpAdaptor adaptor,
399  ConversionPatternRewriter &rewriter) const override {
400  auto resultType = dyn_cast<RankedTensorType>(op.getType());
401 
402  Location loc = op.getLoc();
403  int axis = op.getAxis();
404  Value axisValue =
405  arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(axis));
406  int64_t rank = resultType.getRank();
407 
408  SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
409  SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
411  tensor::getMixedSizes(rewriter, op.getLoc(), adaptor.getOperands()[0]);
412 
413  // Pre-compute the offsets along the axis dimension.
414  // The axisOffsets will be of size rank + 1, where the last value
415  // will hold the total size of the tensor along the 'axis' dimension.
416  SmallVector<OpFoldResult> axisOffsets;
417  axisOffsets.push_back(rewriter.getIndexAttr(0));
418  axisOffsets.push_back(sizes[axis]);
419 
420  for (auto arg : adaptor.getOperands().drop_front()) {
421  auto size = rewriter.createOrFold<tensor::DimOp>(loc, arg, axisValue);
422  auto currentOffset =
423  getValueOrCreateConstantIndexOp(rewriter, loc, axisOffsets.back());
424  auto total =
425  rewriter.createOrFold<arith::AddIOp>(loc, currentOffset, size);
426  axisOffsets.push_back(getAsOpFoldResult(total));
427  }
428  sizes[axis] = axisOffsets.back();
429 
430  // Compute the dynamic sizes of the tensor.empty operation.
431  // This is based off of the specified result type of the tosa.concat
432  // operation, since we don't want to change the result type of the operation
433  // during the conversion.
434  SmallVector<Value> dynDims;
435  for (int64_t i = 0; i < rank; ++i) {
436  if (resultType.isDynamicDim(i)) {
437  dynDims.push_back(
438  getValueOrCreateConstantIndexOp(rewriter, loc, sizes[i]));
439  }
440  }
441 
442  Value result =
443  tensor::EmptyOp::create(rewriter, loc, resultType.getShape(),
444  resultType.getElementType(), dynDims);
445 
446  for (auto [arg, offset] : llvm::zip(adaptor.getOperands(), axisOffsets)) {
447  auto sizes = tensor::getMixedSizes(rewriter, op.getLoc(), arg);
448  offsets[axis] = offset;
449  result = rewriter.createOrFold<tensor::InsertSliceOp>(
450  loc, arg, result, offsets, sizes, strides);
451  }
452  rewriter.replaceOp(op, result);
453  return success();
454  }
455 };
456 
457 } // namespace
458 
460  const TypeConverter &converter, RewritePatternSet *patterns) {
461  patterns
462  ->add<ConcatConverter, PadConverter, ReshapeConverter, SliceConverter>(
463  converter, patterns->getContext());
464 }
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:103
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:359
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
This class helps build Operations.
Definition: Builders.h:205
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:517
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:55
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
RankedTensorType clone(ArrayRef< int64_t > shape, Type elementType) const
Return a clone of this type with the given new shape and element type.
Type conversion class.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:197
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition: ArithOps.cpp:359
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:70
void populateTosaToTensorConversionPatterns(const TypeConverter &converter, RewritePatternSet *patterns)
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:488
const FrozenRewritePatternSet & patterns
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:111
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369