MLIR  21.0.0git
TosaToTensor.cpp
Go to the documentation of this file.
1 //===- TosaToTensor.cpp - Lowering Tosa to Tensor Dialect -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // These rewriters lower from the Tosa to the Tensor dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
20 #include "mlir/IR/PatternMatch.h"
22 
23 #include <numeric>
24 
25 using namespace mlir;
26 using namespace tosa;
27 
28 namespace {
29 
30 // Infer the type to which the input of a 'tosa.reshape' op must be cast when
31 // lowered.
32 TensorType inferReshapeInputType(TypedValue<TensorType> input,
33  ArrayRef<int64_t> newShape) {
34  // No need to cast input for non-empty target shape
35  if (!newShape.empty())
36  return input.getType();
37 
38  // The input type must be cast into a tensor with the same rank and all static
39  // dimensions set to 1. This prevents the generation of a
40  // tensor.collapse_shape op that converts a dynamically shaped tensor into a
41  // 0D tensor. While such construct is not incorrect on its own, bufferization
42  // cannot properly handle it at the moment, so we avoid it.
43  SmallVector<int64_t> shape(input.getType().getRank(), 1);
44  return input.getType().clone(shape);
45 }
46 
47 // Infer the result type of 'tensor.expand_shape' in the collapse-expand
48 // pair emitted for a 'tosa.reshape' op.
49 TensorType inferReshapeExpandedType(TensorType inputType,
50  ArrayRef<int64_t> newShape) {
51  // Special case for 0D output tensor. Note: Watch out when using Type::clone()
52  // with just '{}', as it will invoke the incorrect overload.
53  if (newShape.empty())
54  return inputType.clone(ArrayRef<int64_t>{});
55 
56  // Check if the input is static, and if so, get its total size
57  bool inputIsStatic = inputType.hasStaticShape();
58  int64_t totalSize = inputIsStatic ? inputType.getNumElements() : -1;
59 
60  // Compute result shape
61  auto resultShape =
62  llvm::map_to_vector(newShape, [&](int64_t size) -> int64_t {
63  // If this is not a placeholder, do not change it.
64  if (size >= 0)
65  return size;
66 
67  // If we do not know the total size of the tensor, keep this dimension
68  // dynamic in the result shape.
69  if (!inputIsStatic)
70  return ShapedType::kDynamic;
71 
72  // Calculate the product of all elements in 'newShape' except for the -1
73  // placeholder, which we discard by negating the result.
74  int64_t totalSizeNoPlaceholder = -std::accumulate(
75  newShape.begin(), newShape.end(), 1, std::multiplies<int64_t>());
76 
77  // If there is a 0 component in 'newShape', resolve the placeholder as
78  // 0.
79  if (totalSizeNoPlaceholder == 0)
80  return 0;
81 
82  // Resolve the placeholder as the quotient between the total tensor size
83  // and the product of all other sizes.
84  return totalSize / totalSizeNoPlaceholder;
85  });
86 
87  bool resultIsStatic = !ShapedType::isDynamicShape(resultShape);
88 
89  // A syntactic restriction in 'tensor.expand_shape' forbids a dynamically
90  // shaped input from being reshaped into a statically shaped result. We may
91  // simply turn the first result dimension dynamic to address this.
92  if (!inputIsStatic && resultIsStatic)
93  resultShape[0] = ShapedType::kDynamic;
94 
95  // The 'tensor.expand_shape' op also forbids a statically shaped input from
96  // being reshaped into a dynamically shaped result, but the placeholder
97  // inference algorithm above guarantees that this will never be the case.
98  assert(!inputIsStatic || resultIsStatic);
99 
100  // Create result type
101  return inputType.clone(resultShape);
102 }
103 
104 // Infer the result type of 'tensor.collapse_shape' in the collapse-expand
105 // pair emitted for a 'tosa.reshape' op.
106 TensorType inferReshapeCollapsedType(TensorType lhsType, TensorType rhsType) {
107  auto lhsShape = lhsType.getShape();
108  auto rhsShape = rhsType.getShape();
109 
110  if (lhsShape.empty() || rhsShape.empty())
111  return lhsType.clone(ArrayRef<int64_t>{});
112 
113  if (ShapedType::isDynamicShape(lhsShape) ||
114  ShapedType::isDynamicShape(rhsShape))
115  return lhsType.clone({ShapedType::kDynamic});
116 
117  SmallVector<int64_t> intermediateShape;
118  unsigned currLhsDim = 0, currRhsDim = 0;
119  while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) {
120  int64_t rhsSize = rhsShape[currRhsDim];
121  int64_t lhsSize = lhsShape[currLhsDim];
122  while (lhsSize != rhsSize && currLhsDim < lhsShape.size() &&
123  currRhsDim < rhsShape.size()) {
124  if (lhsSize < rhsSize) {
125  currLhsDim++;
126  if (currLhsDim < lhsShape.size()) {
127  lhsSize *= lhsShape[currLhsDim];
128  }
129  } else {
130  currRhsDim++;
131  if (currRhsDim < rhsShape.size()) {
132  rhsSize *= rhsShape[currRhsDim];
133  }
134  }
135  }
136  if (lhsSize == rhsSize) {
137  intermediateShape.push_back(lhsSize);
138  }
139  currRhsDim++;
140  currLhsDim++;
141  }
142 
143  // Static shapes are guaranteed to be compatible by the op verifier, so all
144  // leftover dimensions should be 1.
145  for (; currLhsDim < lhsShape.size(); currLhsDim++) {
146  assert(lhsShape[currLhsDim] == 1);
147  }
148  for (; currRhsDim < rhsShape.size(); currRhsDim++) {
149  assert(rhsShape[currRhsDim] == 1);
150  }
151 
152  return lhsType.clone(intermediateShape);
153 }
154 
156 createReassociationMapForCollapse(OpBuilder &builder, Type srcType,
157  Type dstType) {
158  auto srcShape = cast<TensorType>(srcType).getShape();
159  auto dstShape = cast<TensorType>(dstType).getShape();
160 
161  if (srcShape.empty() || dstShape.empty())
162  return {};
163 
164  if (ShapedType::isDynamicShape(srcShape) ||
165  ShapedType::isDynamicShape(dstShape)) {
166  assert(dstShape.size() == 1);
168  for (auto i : llvm::seq<int64_t>(srcShape.size()))
169  exprs.push_back(builder.getAffineDimExpr(i));
170  return {exprs};
171  }
172 
173  SmallVector<ReassociationExprs> reassociationMap(dstShape.size());
174  unsigned currSrcDim = 0, currDstDim = 0;
175  while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
176  int64_t dstSize = dstShape[currDstDim];
177  int64_t srcSize = srcShape[currSrcDim];
178  while (srcSize < dstSize && currSrcDim < srcShape.size()) {
179  reassociationMap[currDstDim].push_back(
180  builder.getAffineDimExpr(currSrcDim++));
181  srcSize *= srcShape[currSrcDim];
182  }
183  if (srcSize == dstSize) {
184  reassociationMap[currDstDim].push_back(
185  builder.getAffineDimExpr(currSrcDim++));
186  // If the next dim in collapsedShape is not 1, treat subsequent dims in
187  // expandedShape which are 1 to be collapsed.
188  if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
189  while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
190  reassociationMap[currDstDim].push_back(
191  builder.getAffineDimExpr(currSrcDim++));
192  }
193  }
194  }
195  currDstDim++;
196  }
197 
198  // If the source and target shapes are compatible, both iterators must have
199  // reached the end. This condition is guaranteed by the op verifier for
200  // static shapes.
201  assert(currSrcDim == srcShape.size() && currDstDim == dstShape.size());
202  return reassociationMap;
203 }
204 
205 // Create a tensor.collapse_shape op that reshapes the input into the given
206 // result type.
207 Value createCollapse(OpBuilder &builder, Location loc, TensorType resultType,
208  Value input) {
209  auto reassociationMap =
210  createReassociationMapForCollapse(builder, input.getType(), resultType);
211  return builder.createOrFold<tensor::CollapseShapeOp>(loc, resultType, input,
212  reassociationMap);
213 }
214 
215 // Create a tensor.expand_shape op that reshapes the input into the given result
216 // type.
217 Value createExpand(OpBuilder &builder, Location loc, TensorType resultType,
218  Value input) {
219  auto reassociationMap =
220  createReassociationMapForCollapse(builder, resultType, input.getType());
221  return builder.createOrFold<tensor::ExpandShapeOp>(loc, resultType, input,
222  reassociationMap);
223 }
224 
225 class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
226 public:
228 
229  LogicalResult
230  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
231  ConversionPatternRewriter &rewriter) const final {
232  auto loc = reshape.getLoc();
233  auto resultType = cast_if_present<ShapedType>(
234  getTypeConverter()->convertType(reshape.getType()));
235  if (!resultType) {
236  return rewriter.notifyMatchFailure(reshape.getLoc(),
237  "could not convert result type");
238  }
239  auto input = dyn_cast<TypedValue<TensorType>>(adaptor.getInput1());
240  if (!input) {
241  return rewriter.notifyMatchFailure(reshape.getLoc(),
242  "expected input type to be tensor");
243  }
244 
246  if (!tosa::getConstShapeValues(reshape.getShape().getDefiningOp(),
247  newShape)) {
248  return failure();
249  }
250 
251  // Infer all intermediate types
252  auto inputType = inferReshapeInputType(input, newShape);
253  auto expandedType = inferReshapeExpandedType(inputType, newShape);
254  auto collapsedType = inferReshapeCollapsedType(inputType, expandedType);
255 
256  // Cast input if needed
257  auto castInput =
258  rewriter.createOrFold<tensor::CastOp>(loc, inputType, input);
259 
260  // Emit collaspe-expand pair
261  auto collapsed = createCollapse(rewriter, loc, collapsedType, castInput);
262  auto expanded = createExpand(rewriter, loc, expandedType, collapsed);
263 
264  // Cast to final result type if needed
265  auto result =
266  rewriter.createOrFold<tensor::CastOp>(loc, resultType, expanded);
267  rewriter.replaceOp(reshape, result);
268  return success();
269  }
270 };
271 
272 class SliceConverter : public OpConversionPattern<tosa::SliceOp> {
273 public:
275 
276  LogicalResult
277  matchAndRewrite(tosa::SliceOp sliceOp, OpAdaptor adaptor,
278  ConversionPatternRewriter &rewriter) const final {
279  Location loc = sliceOp.getLoc();
280  Value input = adaptor.getInput1();
281  ShapedType resultType = cast<ShapedType>(sliceOp.getType());
282  if (llvm::isa<UnrankedTensorType>(resultType))
283  return failure();
284 
285  ElementsAttr startElems;
286  ElementsAttr sizeElems;
287 
288  if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
289  return rewriter.notifyMatchFailure(
290  sliceOp, "start of slice must be a static ranked shape");
291 
292  if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
293  return rewriter.notifyMatchFailure(
294  sliceOp, "size of slice must be a static ranked shape");
295 
296  llvm::SmallVector<int64_t> sliceStarts =
297  llvm::to_vector(startElems.getValues<int64_t>());
298  llvm::SmallVector<int64_t> sliceSizes =
299  llvm::to_vector(sizeElems.getValues<int64_t>());
300 
301  SmallVector<int64_t> strides, sizes;
302  strides.resize(cast<ShapedType>(sliceOp.getType()).getRank(), 1);
303 
304  SmallVector<Value> dynSizes;
305  for (const auto &i : llvm::enumerate(sliceSizes)) {
306  int64_t size = i.value();
307  size_t index = i.index();
308  sizes.push_back(size == -1 ? ShapedType::kDynamic : size);
309  if (!ShapedType::isDynamic(sizes.back()))
310  continue;
311 
312  auto dim = rewriter.create<tensor::DimOp>(loc, input, index);
313  auto offset = rewriter.create<arith::ConstantOp>(
314  loc, rewriter.getIndexAttr(sliceStarts[index]));
315  dynSizes.push_back(rewriter.create<arith::SubIOp>(loc, dim, offset));
316  }
317 
318  auto newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
319  sliceOp.getLoc(), sliceOp.getType(), input, ValueRange({}), dynSizes,
320  ValueRange({}), rewriter.getDenseI64ArrayAttr(sliceStarts),
321  rewriter.getDenseI64ArrayAttr(sizes),
322  rewriter.getDenseI64ArrayAttr(strides));
323 
324  rewriter.replaceOp(sliceOp, newSliceOp.getResult());
325 
326  // Remove const_shape ops when it no longer has use point.
327  Operation *startConstShape = sliceOp.getStart().getDefiningOp();
328  if (startConstShape->getResult(0).hasOneUse())
329  rewriter.eraseOp(startConstShape);
330 
331  Operation *sizeConstShape = sliceOp.getSize().getDefiningOp();
332  if (sizeConstShape->getResult(0).hasOneUse())
333  rewriter.eraseOp(sizeConstShape);
334 
335  return success();
336  }
337 };
338 
339 class PadConverter : public OpConversionPattern<tosa::PadOp> {
340 public:
342 
343  LogicalResult
344  matchAndRewrite(tosa::PadOp padOp, OpAdaptor adaptor,
345  ConversionPatternRewriter &rewriter) const final {
346  auto loc = padOp.getLoc();
347  auto input = padOp.getInput1();
348 
349  ElementsAttr paddingElems;
350  if (!matchPattern(padOp.getPadding(), m_Constant(&paddingElems))) {
351  return rewriter.notifyMatchFailure(
352  padOp, "padding must be a static shape value");
353  }
354  llvm::SmallVector<int64_t> paddingVals;
355  for (auto idx : paddingElems.getValues<IntegerAttr>()) {
356  paddingVals.push_back(static_cast<int64_t>(idx.getInt()));
357  }
358 
359  ShapedType inputTy = cast<ShapedType>(input.getType());
360  int64_t rank = inputTy.getRank();
361 
362  // Setup the default constantAttr.
363 
364  Value padConstant = rewriter.createOrFold<tensor::ExtractOp>(
365  loc, padOp.getPadConst(),
366  ValueRange({rewriter.create<arith::ConstantIndexOp>(loc, 0)}));
367 
368  if (!padConstant) {
369  return rewriter.notifyMatchFailure(
370  padOp, "tosa.pad was unable to determine the pad constant value.");
371  }
372 
374  SmallVector<OpFoldResult, 3> highValues;
375 
376  lowValues.reserve(rank);
377  highValues.reserve(rank);
378 
379  for (int i = 0; i < rank; i++) {
380  Value lowVal = rewriter.create<arith::ConstantOp>(
381  loc, rewriter.getIndexAttr(paddingVals[2 * i]));
382  Value highVal = rewriter.create<arith::ConstantOp>(
383  loc, rewriter.getIndexAttr(paddingVals[2 * i + 1]));
384  lowValues.push_back(lowVal);
385  highValues.push_back(highVal);
386  }
387 
388  auto newPadOp = rewriter.create<tensor::PadOp>(
389  loc, padOp.getType(), input, lowValues, highValues, padConstant);
390 
391  rewriter.replaceOp(padOp, newPadOp.getResult());
392  return success();
393  }
394 };
395 
396 struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
398 
399  LogicalResult
400  matchAndRewrite(tosa::ConcatOp op, OpAdaptor adaptor,
401  ConversionPatternRewriter &rewriter) const override {
402  auto resultType = dyn_cast<RankedTensorType>(op.getType());
403 
404  Location loc = op.getLoc();
405  int axis = op.getAxis();
406  Value axisValue =
407  rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(axis));
408  int64_t rank = resultType.getRank();
409 
410  SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
411  SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
413  tensor::getMixedSizes(rewriter, op.getLoc(), adaptor.getOperands()[0]);
414 
415  // Pre-compute the offsets along the axis dimension.
416  // The axisOffsets will be of size rank + 1, where the last value
417  // will hold the total size of the tensor along the 'axis' dimension.
418  SmallVector<OpFoldResult> axisOffsets;
419  axisOffsets.push_back(rewriter.getIndexAttr(0));
420  axisOffsets.push_back(sizes[axis]);
421 
422  for (auto arg : adaptor.getOperands().drop_front()) {
423  auto size = rewriter.createOrFold<tensor::DimOp>(loc, arg, axisValue);
424  auto currentOffset =
425  getValueOrCreateConstantIndexOp(rewriter, loc, axisOffsets.back());
426  auto total =
427  rewriter.createOrFold<arith::AddIOp>(loc, currentOffset, size);
428  axisOffsets.push_back(getAsOpFoldResult(total));
429  }
430  sizes[axis] = axisOffsets.back();
431 
432  // Compute the dynamic sizes of the tensor.empty operation.
433  // This is based off of the specified result type of the tosa.concat
434  // operation, since we don't want to change the result type of the operation
435  // during the conversion.
436  SmallVector<Value> dynDims;
437  for (int64_t i = 0; i < rank; ++i) {
438  if (resultType.isDynamicDim(i)) {
439  dynDims.push_back(
440  getValueOrCreateConstantIndexOp(rewriter, loc, sizes[i]));
441  }
442  }
443 
444  Value result = rewriter.create<tensor::EmptyOp>(
445  loc, resultType.getShape(), resultType.getElementType(), dynDims);
446 
447  for (auto [arg, offset] : llvm::zip(adaptor.getOperands(), axisOffsets)) {
448  auto sizes = tensor::getMixedSizes(rewriter, op.getLoc(), arg);
449  offsets[axis] = offset;
450  result = rewriter.createOrFold<tensor::InsertSliceOp>(
451  loc, arg, result, offsets, sizes, strides);
452  }
453  rewriter.replaceOp(op, result);
454  return success();
455  }
456 };
457 
458 } // namespace
459 
461  const TypeConverter &converter, RewritePatternSet *patterns) {
462  patterns
463  ->add<ConcatConverter, PadConverter, ReshapeConverter, SliceConverter>(
464  converter, patterns->getContext());
465 }
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:106
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:362
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
This class helps build Operations.
Definition: Builders.h:205
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:518
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:455
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:55
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
RankedTensorType clone(ArrayRef< int64_t > shape, Type elementType) const
Return a clone of this type with the given new shape and element type.
Type conversion class.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:197
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:73
void populateTosaToTensorConversionPatterns(const TypeConverter &converter, RewritePatternSet *patterns)
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:488
const FrozenRewritePatternSet & patterns
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369