//===- TosaToTensor.cpp - Lowering Tosa to Tensor Dialect ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// These rewriters lower from the Tosa to the Tensor dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"

#include <numeric>

using namespace mlir;
using namespace tosa;

namespace {

// Infer the type to which the input of a 'tosa.reshape' op must be cast when
// lowered.
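// For example (an illustrative case, not exhaustive): reshaping
// tensor<?x?xf32> to an empty target shape would otherwise collapse a
// dynamically shaped tensor straight to tensor<f32>; casting the input to
// tensor<1x1xf32> first lets the collapse proceed on a fully static type.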
TensorType inferReshapeInputType(TypedValue<TensorType> input,
                                 ArrayRef<int64_t> newShape) {
  // No need to cast the input for a non-empty target shape.
  if (!newShape.empty())
    return input.getType();

  // The input type must be cast into a tensor with the same rank and all
  // static dimensions set to 1. This prevents the generation of a
  // tensor.collapse_shape op that converts a dynamically shaped tensor into a
  // 0D tensor. While such a construct is not incorrect on its own,
  // bufferization cannot properly handle it at the moment, so we avoid it.
  SmallVector<int64_t> shape(input.getType().getRank(), 1);
  return input.getType().clone(shape);
}

// Infer the result type of 'tensor.expand_shape' in the collapse-expand
// pair emitted for a 'tosa.reshape' op.
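// A worked example (illustrative): for a static input tensor<2x3x4xf32>
// (24 elements) and newShape = [4, -1], the -1 placeholder resolves to
// 24 / 4 = 6, giving tensor<4x6xf32>. For a dynamic input such as
// tensor<?x4xf32>, the placeholder stays dynamic and the result is
// tensor<4x?xf32>.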
TensorType inferReshapeExpandedType(TensorType inputType,
                                    ArrayRef<int64_t> newShape) {
  // Special case for 0D output tensor. Note: Watch out when using
  // Type::clone() with just '{}', as it will invoke the incorrect overload.
  if (newShape.empty())
    return inputType.clone(ArrayRef<int64_t>{});

  // Check if the input is static, and if so, get its total size.
  bool inputIsStatic = inputType.hasStaticShape();
  int64_t totalSize = inputIsStatic ? inputType.getNumElements() : -1;

  // Compute the result shape.
  auto resultShape =
      llvm::map_to_vector(newShape, [&](int64_t size) -> int64_t {
        // If this is not a placeholder, do not change it.
        if (size >= 0)
          return size;

        // If we do not know the total size of the tensor, keep this dimension
        // dynamic in the result shape.
        if (!inputIsStatic)
          return ShapedType::kDynamic;

        // Calculate the product of all elements in 'newShape' except for the
        // -1 placeholder, which we discard by negating the result.
        int64_t totalSizeNoPlaceholder = -llvm::product_of(newShape);

        // If there is a 0 component in 'newShape', resolve the placeholder
        // as 0.
        if (totalSizeNoPlaceholder == 0)
          return 0;

        // Resolve the placeholder as the quotient between the total tensor
        // size and the product of all other sizes.
        return totalSize / totalSizeNoPlaceholder;
      });

  bool resultIsStatic = ShapedType::isStaticShape(resultShape);

  // A syntactic restriction in 'tensor.expand_shape' forbids a dynamically
  // shaped input from being reshaped into a statically shaped result. We may
  // simply turn the first result dimension dynamic to address this.
  if (!inputIsStatic && resultIsStatic)
    resultShape[0] = ShapedType::kDynamic;

  // The 'tensor.expand_shape' op also forbids a statically shaped input from
  // being reshaped into a dynamically shaped result, but the placeholder
  // inference algorithm above guarantees that this will never be the case.
  assert(!inputIsStatic || resultIsStatic);

  // Create the result type.
  return inputType.clone(resultShape);
}

// Infer the result type of 'tensor.collapse_shape' in the collapse-expand
// pair emitted for a 'tosa.reshape' op.
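// A worked example (illustrative): for lhsType tensor<2x3x4xf32> and rhsType
// tensor<6x4xf32>, dimensions are accumulated on whichever side is currently
// smaller until the running products match: 2*3 pairs with 6 and 4 pairs
// with 4, giving the intermediate type tensor<6x4xf32>.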
TensorType inferReshapeCollapsedType(TensorType lhsType, TensorType rhsType) {
  auto lhsShape = lhsType.getShape();
  auto rhsShape = rhsType.getShape();

  if (lhsShape.empty() || rhsShape.empty())
    return lhsType.clone(ArrayRef<int64_t>{});

  if (ShapedType::isDynamicShape(lhsShape) ||
      ShapedType::isDynamicShape(rhsShape))
    return lhsType.clone({ShapedType::kDynamic});

  SmallVector<int64_t> intermediateShape;
  unsigned currLhsDim = 0, currRhsDim = 0;
  while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) {
    int64_t rhsSize = rhsShape[currRhsDim];
    int64_t lhsSize = lhsShape[currLhsDim];
    while (lhsSize != rhsSize && currLhsDim < lhsShape.size() &&
           currRhsDim < rhsShape.size()) {
      if (lhsSize < rhsSize) {
        currLhsDim++;
        if (currLhsDim < lhsShape.size()) {
          lhsSize *= lhsShape[currLhsDim];
        }
      } else {
        currRhsDim++;
        if (currRhsDim < rhsShape.size()) {
          rhsSize *= rhsShape[currRhsDim];
        }
      }
    }
    if (lhsSize == rhsSize) {
      intermediateShape.push_back(lhsSize);
    }
    currRhsDim++;
    currLhsDim++;
  }

  // Static shapes are guaranteed to be compatible by the op verifier, so all
  // leftover dimensions should be 1.
  for (; currLhsDim < lhsShape.size(); currLhsDim++) {
    assert(lhsShape[currLhsDim] == 1);
  }
  for (; currRhsDim < rhsShape.size(); currRhsDim++) {
    assert(rhsShape[currRhsDim] == 1);
  }

  return lhsType.clone(intermediateShape);
}

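// Create the reassociation map for collapsing 'srcType' into 'dstType'. A
// worked example (illustrative): collapsing tensor<2x3x4xf32> into
// tensor<6x4xf32> yields [[d0, d1], [d2]], grouping the first two source
// dimensions into the first result dimension.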
SmallVector<ReassociationExprs>
createReassociationMapForCollapse(OpBuilder &builder, Type srcType,
                                  Type dstType) {
  auto srcShape = cast<TensorType>(srcType).getShape();
  auto dstShape = cast<TensorType>(dstType).getShape();

  if (srcShape.empty() || dstShape.empty())
    return {};

  if (ShapedType::isDynamicShape(srcShape) ||
      ShapedType::isDynamicShape(dstShape)) {
    assert(dstShape.size() == 1);
    SmallVector<AffineExpr> exprs;
    for (auto i : llvm::seq<int64_t>(srcShape.size()))
      exprs.push_back(builder.getAffineDimExpr(i));
    return {exprs};
  }

  SmallVector<ReassociationExprs> reassociationMap(dstShape.size());
  unsigned currSrcDim = 0, currDstDim = 0;
  while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
    int64_t dstSize = dstShape[currDstDim];
    int64_t srcSize = srcShape[currSrcDim];
    while (srcSize < dstSize && currSrcDim < srcShape.size()) {
      reassociationMap[currDstDim].push_back(
          builder.getAffineDimExpr(currSrcDim++));
      srcSize *= srcShape[currSrcDim];
    }
    if (srcSize == dstSize) {
      reassociationMap[currDstDim].push_back(
          builder.getAffineDimExpr(currSrcDim++));
      // If the next dim in the collapsed shape is not 1, fold any subsequent
      // unit dims of the expanded shape into the current group.
      if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
        while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
          reassociationMap[currDstDim].push_back(
              builder.getAffineDimExpr(currSrcDim++));
        }
      }
    }
    currDstDim++;
  }

  // If the source and target shapes are compatible, both iterators must have
  // reached the end. This condition is guaranteed by the op verifier for
  // static shapes.
  assert(currSrcDim == srcShape.size() && currDstDim == dstShape.size());
  return reassociationMap;
}

// Create a tensor.collapse_shape op that reshapes the input into the given
// result type.
Value createCollapse(OpBuilder &builder, Location loc, TensorType resultType,
                     Value input) {
  auto reassociationMap =
      createReassociationMapForCollapse(builder, input.getType(), resultType);
  return builder.createOrFold<tensor::CollapseShapeOp>(loc, resultType, input,
                                                       reassociationMap);
}

// Create a tensor.expand_shape op that reshapes the input into the given
// result type.
Value createExpand(OpBuilder &builder, Location loc, TensorType resultType,
                   Value input) {
  auto reassociationMap =
      createReassociationMapForCollapse(builder, resultType, input.getType());
  return builder.createOrFold<tensor::ExpandShapeOp>(loc, resultType, input,
                                                     reassociationMap);
}

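// As an illustrative sketch (types and assembly syntax may vary across MLIR
// versions), lowering a 'tosa.reshape' of tensor<2x3xf32> into
// tensor<3x2xf32> produces roughly:
//
//   %c = tensor.collapse_shape %in [[0, 1]]
//       : tensor<2x3xf32> into tensor<6xf32>
//   %e = tensor.expand_shape %c [[0, 1]] output_shape [3, 2]
//       : tensor<6xf32> into tensor<3x2xf32>
//
// with tensor.cast ops inserted (and folded away when trivial) on both ends.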
class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
public:
  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    auto loc = reshape.getLoc();
    auto resultType =
        getTypeConverter()->convertType<ShapedType>(reshape.getType());
    if (!resultType) {
      return rewriter.notifyMatchFailure(reshape.getLoc(),
                                         "could not convert result type");
    }
    auto input = dyn_cast<TypedValue<TensorType>>(adaptor.getInput1());
    if (!input) {
      return rewriter.notifyMatchFailure(reshape.getLoc(),
                                         "expected input type to be tensor");
    }

    llvm::SmallVector<int64_t> newShape;
    if (!tosa::getConstShapeValues(reshape.getShape().getDefiningOp(),
                                   newShape)) {
      return failure();
    }

    // Infer all intermediate types.
    auto inputType = inferReshapeInputType(input, newShape);
    auto expandedType = inferReshapeExpandedType(inputType, newShape);
    auto collapsedType = inferReshapeCollapsedType(inputType, expandedType);

    // Cast the input if needed.
    auto castInput =
        rewriter.createOrFold<tensor::CastOp>(loc, inputType, input);

    // Emit the collapse-expand pair.
    auto collapsed = createCollapse(rewriter, loc, collapsedType, castInput);
    auto expanded = createExpand(rewriter, loc, expandedType, collapsed);

    // Cast to the final result type if needed.
    auto result =
        rewriter.createOrFold<tensor::CastOp>(loc, resultType, expanded);
    rewriter.replaceOp(reshape, result);
    return success();
  }
};

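// An illustrative sketch of this lowering (syntax approximate): a
// 'tosa.slice' of tensor<8x8xf32> with constant start [1, 2] and size [2, 4]
// becomes roughly:
//
//   %0 = tensor.extract_slice %arg0[1, 2] [2, 4] [1, 1]
//       : tensor<8x8xf32> to tensor<2x4xf32>
//
// Dynamic sizes (-1 entries) are materialized as 'dim - start' values.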
class SliceConverter : public OpConversionPattern<tosa::SliceOp> {
public:
  using OpConversionPattern<tosa::SliceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::SliceOp sliceOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = sliceOp.getLoc();
    Value input = adaptor.getInput1();
    ShapedType resultType = cast<ShapedType>(sliceOp.getType());
    if (llvm::isa<UnrankedTensorType>(resultType))
      return failure();

    ElementsAttr startElems;
    ElementsAttr sizeElems;

    if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
      return rewriter.notifyMatchFailure(
          sliceOp, "start of slice must be a static ranked shape");

    if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
      return rewriter.notifyMatchFailure(
          sliceOp, "size of slice must be a static ranked shape");

    llvm::SmallVector<int64_t> sliceStarts =
        llvm::to_vector(startElems.getValues<int64_t>());
    llvm::SmallVector<int64_t> sliceSizes =
        llvm::to_vector(sizeElems.getValues<int64_t>());

    SmallVector<int64_t> strides, sizes;
    strides.resize(cast<ShapedType>(sliceOp.getType()).getRank(), 1);

    SmallVector<Value> dynSizes;
    for (const auto &i : llvm::enumerate(sliceSizes)) {
      int64_t size = i.value();
      size_t index = i.index();
      sizes.push_back(size == -1 ? ShapedType::kDynamic : size);
      if (ShapedType::isStatic(sizes.back()))
        continue;

      auto dim = tensor::DimOp::create(rewriter, loc, input, index);
      auto offset = arith::ConstantOp::create(
          rewriter, loc, rewriter.getIndexAttr(sliceStarts[index]));
      dynSizes.push_back(arith::SubIOp::create(rewriter, loc, dim, offset));
    }

    auto newSliceOp = tensor::ExtractSliceOp::create(
        rewriter, sliceOp.getLoc(), sliceOp.getType(), input, ValueRange({}),
        dynSizes, ValueRange({}), rewriter.getDenseI64ArrayAttr(sliceStarts),
        rewriter.getDenseI64ArrayAttr(sizes),
        rewriter.getDenseI64ArrayAttr(strides));

    // Remove the const_shape ops if they no longer have any other uses.
    Operation *startConstShape = sliceOp.getStart().getDefiningOp();
    if (startConstShape->getResult(0).hasOneUse())
      rewriter.eraseOp(startConstShape);

    Operation *sizeConstShape = sliceOp.getSize().getDefiningOp();
    if (sizeConstShape->getResult(0).hasOneUse())
      rewriter.eraseOp(sizeConstShape);

    rewriter.replaceOp(sliceOp, newSliceOp.getResult());
    return success();
  }
};

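// An illustrative sketch of this lowering (syntax approximate): padding a
// tensor<1x2xf32> with padding [0, 0, 1, 1] (low/high pairs per dimension)
// and pad constant %c becomes roughly:
//
//   %0 = tensor.pad %arg0 low[0, 1] high[0, 1] {
//   ^bb0(%i: index, %j: index):
//     tensor.yield %c : f32
//   } : tensor<1x2xf32> to tensor<1x4xf32>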
class PadConverter : public OpConversionPattern<tosa::PadOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::PadOp padOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    auto loc = padOp.getLoc();
    auto input = padOp.getInput1();

    ElementsAttr paddingElems;
    if (!matchPattern(padOp.getPadding(), m_Constant(&paddingElems))) {
      return rewriter.notifyMatchFailure(
          padOp, "padding must be a static shape value");
    }
    llvm::SmallVector<int64_t> paddingVals;
    for (auto idx : paddingElems.getValues<IntegerAttr>()) {
      paddingVals.push_back(static_cast<int64_t>(idx.getInt()));
    }

    ShapedType inputTy = cast<ShapedType>(input.getType());
    int64_t rank = inputTy.getRank();

    // Set up the pad constant by extracting the scalar value from the
    // pad_const tensor.
    Value padConstant = rewriter.createOrFold<tensor::ExtractOp>(
        loc, padOp.getPadConst(),
        ValueRange({arith::ConstantIndexOp::create(rewriter, loc, 0)}));

    if (!padConstant) {
      return rewriter.notifyMatchFailure(
          padOp, "tosa.pad was unable to determine the pad constant value.");
    }

    SmallVector<OpFoldResult, 3> lowValues;
    SmallVector<OpFoldResult, 3> highValues;

    lowValues.reserve(rank);
    highValues.reserve(rank);

    for (int i = 0; i < rank; i++) {
      Value lowVal = arith::ConstantOp::create(
          rewriter, loc, rewriter.getIndexAttr(paddingVals[2 * i]));
      Value highVal = arith::ConstantOp::create(
          rewriter, loc, rewriter.getIndexAttr(paddingVals[2 * i + 1]));
      lowValues.push_back(lowVal);
      highValues.push_back(highVal);
    }

    auto newPadOp = tensor::PadOp::create(rewriter, loc, padOp.getType(), input,
                                          lowValues, highValues, padConstant);

    rewriter.replaceOp(padOp, newPadOp.getResult());
    return success();
  }
};

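// An illustrative sketch of this lowering (syntax approximate): concatenating
// two tensor<1x2xf32> values along axis 0 becomes roughly:
//
//   %e = tensor.empty() : tensor<2x2xf32>
//   %0 = tensor.insert_slice %arg0 into %e[0, 0] [1, 2] [1, 1]
//       : tensor<1x2xf32> into tensor<2x2xf32>
//   %1 = tensor.insert_slice %arg1 into %0[1, 0] [1, 2] [1, 1]
//       : tensor<1x2xf32> into tensor<2x2xf32>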
struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
  using OpConversionPattern<tosa::ConcatOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::ConcatOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto resultType = dyn_cast<RankedTensorType>(op.getType());

    Location loc = op.getLoc();
    int axis = op.getAxis();
    Value axisValue =
        arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(axis));
    int64_t rank = resultType.getRank();

    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes =
        tensor::getMixedSizes(rewriter, op.getLoc(), adaptor.getOperands()[0]);

    // Pre-compute the offsets along the axis dimension. The axisOffsets
    // vector holds one entry per operand plus one; the last value holds the
    // total size of the result along the 'axis' dimension.
    SmallVector<OpFoldResult> axisOffsets;
    axisOffsets.push_back(rewriter.getIndexAttr(0));
    axisOffsets.push_back(sizes[axis]);

    for (auto arg : adaptor.getOperands().drop_front()) {
      auto size = rewriter.createOrFold<tensor::DimOp>(loc, arg, axisValue);
      auto currentOffset =
          getValueOrCreateConstantIndexOp(rewriter, loc, axisOffsets.back());
      auto total =
          rewriter.createOrFold<arith::AddIOp>(loc, currentOffset, size);
      axisOffsets.push_back(getAsOpFoldResult(total));
    }
    sizes[axis] = axisOffsets.back();

    // Compute the dynamic sizes of the tensor.empty operation. This is based
    // on the specified result type of the tosa.concat operation, since we
    // don't want to change the result type of the operation during the
    // conversion.
    SmallVector<Value> dynDims;
    for (int64_t i = 0; i < rank; ++i) {
      if (resultType.isDynamicDim(i)) {
        dynDims.push_back(
            getValueOrCreateConstantIndexOp(rewriter, loc, sizes[i]));
      }
    }

    Value result =
        tensor::EmptyOp::create(rewriter, loc, resultType.getShape(),
                                resultType.getElementType(), dynDims);

    for (auto [arg, offset] : llvm::zip(adaptor.getOperands(), axisOffsets)) {
      auto sizes = tensor::getMixedSizes(rewriter, op.getLoc(), arg);
      offsets[axis] = offset;
      result = rewriter.createOrFold<tensor::InsertSliceOp>(
          loc, arg, result, offsets, sizes, strides);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

} // namespace

void mlir::populateTosaToTensorConversionPatterns(
    const TypeConverter &converter, RewritePatternSet *patterns) {
  patterns
      ->add<ConcatConverter, PadConverter, ReshapeConverter, SliceConverter>(
          converter, patterns->getContext());
}