MLIR  22.0.0git
TensorInferTypeOpInterfaceImpl.cpp
Go to the documentation of this file.
1 //===- InferTypeOpImpl.cpp - InferType Interface external models *- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
14 
15 using namespace mlir;
16 using namespace mlir::tensor;
17 
18 /// For reshape op compute the shape at dimension `dimIndex` of the output in
19 /// terms of shape of the `src`, when the reshape op is a collapsing
20 /// operation. It is the product of the shape of the collapsed dimensions of the
21 /// `src`.
23  OpBuilder &builder, Location loc, int64_t dimIndex, Value src,
24  ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociationMap) {
25  if (ShapedType::isStatic(dstStaticShape[dimIndex])) {
26  // Static dimension: return Attribute.
27  return builder.getIndexAttr(dstStaticShape[dimIndex]);
28  }
29  AffineMap map = reassociationMap[dimIndex];
30  unsigned startPos =
31  cast<AffineDimExpr>(map.getResults().front()).getPosition();
32  unsigned endPos = cast<AffineDimExpr>(map.getResults().back()).getPosition();
33  AffineExpr expr;
34  SmallVector<OpFoldResult> dynamicDims;
35  for (auto dim : llvm::seq_inclusive(startPos, endPos)) {
36  dynamicDims.push_back(builder.createOrFold<tensor::DimOp>(loc, src, dim));
37  AffineExpr currExpr = builder.getAffineSymbolExpr(dim - startPos);
38  expr = (expr ? expr * currExpr : currExpr);
39  }
40 
41  // Dynamic dimension: return Value.
43  builder, loc, AffineMap::get(0, endPos - startPos + 1, expr),
44  dynamicDims)
45  ->getResult(0);
46 }
47 
48 /// Given the `src` of a collapsing reshape op and its reassociation maps,
49 /// compute the shape of the result of the reshape.
51  OpBuilder &builder, Location loc, Value src,
52  ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation) {
53  return llvm::to_vector<4>(llvm::map_range(
54  llvm::seq<int64_t>(0, dstStaticShape.size()), [&](int64_t dim) {
55  return getCollapsedOutputDimFromInputShape(
56  builder, loc, dim, src, dstStaticShape, reassociation);
57  }));
58 }
59 
61  : public ReifyRankedShapedTypeOpInterface::ExternalModel<
62  ReifyCollapseShapeOp, CollapseShapeOp> {
63  LogicalResult
65  ReifiedRankedShapedTypeDims &reifiedReturnShapes) const {
66  auto loc = op->getLoc();
67  auto reshapeOp = cast<tensor::CollapseShapeOp>(op);
68  reifiedReturnShapes.push_back(getCollapsedOutputShapeFromInputShape(
69  b, loc, reshapeOp.getSrc(), reshapeOp.getResultType().getShape(),
70  reshapeOp.getReassociationMaps()));
71  return success();
72  }
73 };
74 
75 namespace {
76 
77 struct ReifyExpandShapeOp
78  : public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyExpandShapeOp,
79  ExpandShapeOp> {
80  LogicalResult
83  auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
84  SmallVector<OpFoldResult> resultShapes =
85  expandShapeOp.getMixedOutputShape();
86  reifyResultShapes.emplace_back(std::move(resultShapes));
87  return success();
88  }
89 };
90 
91 struct ReifyPadOp
92  : public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyPadOp,
93  PadOp> {
94  LogicalResult
96  ReifiedRankedShapedTypeDims &reifiedReturnShapes) const {
97  auto padOp = cast<PadOp>(op);
98  Location loc = padOp.getLoc();
99  auto lowPad = padOp.getMixedLowPad();
100  auto highPad = padOp.getMixedHighPad();
102  for (auto dim : llvm::seq<int64_t>(0, padOp.getSourceType().getRank())) {
103  if (!padOp.getResultType().isDynamicDim(dim)) {
104  shapes.push_back(b.getIndexAttr(padOp.getResultType().getDimSize(dim)));
105  continue;
106  }
107 
108  // Shape along each dimension is source dim + low pad + high pad.
109  SmallVector<OpFoldResult> mapOperands;
110  mapOperands.push_back(
111  b.createOrFold<tensor::DimOp>(loc, padOp.getSource(), dim));
112  mapOperands.push_back(lowPad[dim]);
113  mapOperands.push_back(highPad[dim]);
114  AffineExpr expr = b.getAffineDimExpr(0) + b.getAffineSymbolExpr(0) +
115  b.getAffineSymbolExpr(1);
116  shapes.push_back(getValueOrCreateConstantIndexOp(
117  b, loc,
119  b, loc, AffineMap::get(1, 2, expr), mapOperands)));
120  }
121  reifiedReturnShapes.emplace_back(std::move(shapes));
122  return success();
123  }
124 };
125 
126 } // namespace
127 
129  DialectRegistry &registry) {
130  registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
131  ExpandShapeOp::attachInterface<ReifyExpandShapeOp>(*ctx);
132  CollapseShapeOp::attachInterface<ReifyCollapseShapeOp>(*ctx);
133  PadOp::attachInterface<ReifyPadOp>(*ctx);
134  });
135 }
static OpFoldResult getCollapsedOutputDimFromInputShape(OpBuilder &builder, Location loc, int64_t dimIndex, Value src, ArrayRef< int64_t > dstStaticShape, ArrayRef< AffineMap > reassociationMap)
For reshape op compute the shape at dimension dimIndex of the output in terms of shape of the src, when the reshape op is a collapsing operation.
static SmallVector< OpFoldResult, 4 > getCollapsedOutputShapeFromInputShape(OpBuilder &builder, Location loc, Value src, ArrayRef< int64_t > dstStaticShape, ArrayRef< AffineMap > reassociation)
Given the src of a collapsing reshape op and its reassociation maps, compute the shape of the result of the reshape.
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:46
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:403
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:103
AffineExpr getAffineSymbolExpr(unsigned position)
Definition: Builders.cpp:363
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:359
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:205
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:517
This class represents a single result from folding an operation.
Definition: OpDefinition.h:272
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
Definition: AffineOps.cpp:1278
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1331
void registerInferTypeOpInterfaceExternalModels(mlir::DialectRegistry &registry)
Registers external models for Infer Type interfaces for tensor ops.
Include the generated interface declarations.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
LogicalResult reifyResultShapes(Operation *op, OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) const