MLIR 23.0.0git
TensorInferTypeOpInterfaceImpl.cpp
Go to the documentation of this file.
//===- TensorInferTypeOpInterfaceImpl.cpp - InferType models ----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
14#include "llvm/ADT/SmallVectorExtras.h"
15
16using namespace mlir;
17using namespace mlir::tensor;
18
19/// For reshape op compute the shape at dimension `dimIndex` of the output in
20/// terms of shape of the `src`, when the reshape op is a collapsing
21/// operation. It is the product of the shape of the collapsed dimensions of the
22/// `src`.
24 OpBuilder &builder, Location loc, int64_t dimIndex, Value src,
25 ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociationMap) {
26 if (ShapedType::isStatic(dstStaticShape[dimIndex])) {
27 // Static dimension: return Attribute.
28 return builder.getIndexAttr(dstStaticShape[dimIndex]);
29 }
30 AffineMap map = reassociationMap[dimIndex];
31 unsigned startPos =
32 cast<AffineDimExpr>(map.getResults().front()).getPosition();
33 unsigned endPos = cast<AffineDimExpr>(map.getResults().back()).getPosition();
34 AffineExpr expr;
35 SmallVector<OpFoldResult> dynamicDims;
36 for (auto dim : llvm::seq_inclusive(startPos, endPos)) {
37 dynamicDims.push_back(builder.createOrFold<tensor::DimOp>(loc, src, dim));
38 AffineExpr currExpr = builder.getAffineSymbolExpr(dim - startPos);
39 expr = (expr ? expr * currExpr : currExpr);
40 }
41
42 // Dynamic dimension: return Value.
44 builder, loc, AffineMap::get(0, endPos - startPos + 1, expr),
45 dynamicDims)
46 ->getResult(0);
47}
48
49/// Given the `src` of a collapsing reshape op and its reassociation maps,
50/// compute the shape of the result of the reshape.
52 OpBuilder &builder, Location loc, Value src,
53 ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation) {
54 return llvm::map_to_vector<4>(
55 llvm::seq<int64_t>(0, dstStaticShape.size()), [&](int64_t dim) {
56 return getCollapsedOutputDimFromInputShape(
57 builder, loc, dim, src, dstStaticShape, reassociation);
58 });
59}
60
62 : public ReifyRankedShapedTypeOpInterface::ExternalModel<
63 ReifyCollapseShapeOp, CollapseShapeOp> {
64 LogicalResult
66 ReifiedRankedShapedTypeDims &reifiedReturnShapes) const {
67 auto loc = op->getLoc();
68 auto reshapeOp = cast<tensor::CollapseShapeOp>(op);
69 reifiedReturnShapes.push_back(getCollapsedOutputShapeFromInputShape(
70 b, loc, reshapeOp.getSrc(), reshapeOp.getResultType().getShape(),
71 reshapeOp.getReassociationMaps()));
72 return success();
73 }
74};
75
76namespace {
77
78struct ReifyExpandShapeOp
79 : public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyExpandShapeOp,
80 ExpandShapeOp> {
81 using Base =
82 ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyExpandShapeOp,
83 ExpandShapeOp>;
84 LogicalResult
85 reifyResultShapes(Operation *op, OpBuilder &b,
87 auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
88 SmallVector<OpFoldResult> resultShapes =
89 expandShapeOp.getMixedOutputShape();
90 reifyResultShapes.emplace_back(std::move(resultShapes));
91 return success();
92 }
93};
94
95struct ReifyPadOp
96 : public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyPadOp,
97 PadOp> {
98 LogicalResult
99 reifyResultShapes(Operation *op, OpBuilder &b,
100 ReifiedRankedShapedTypeDims &reifiedReturnShapes) const {
101 auto padOp = cast<PadOp>(op);
102 Location loc = padOp.getLoc();
103 auto lowPad = padOp.getMixedLowPad();
104 auto highPad = padOp.getMixedHighPad();
105 SmallVector<OpFoldResult> shapes;
106 for (auto dim : llvm::seq<int64_t>(0, padOp.getSourceType().getRank())) {
107 if (!padOp.getResultType().isDynamicDim(dim)) {
108 shapes.push_back(b.getIndexAttr(padOp.getResultType().getDimSize(dim)));
109 continue;
110 }
111
112 // Shape along each dimension is source dim + low pad + high pad.
113 SmallVector<OpFoldResult> mapOperands;
114 mapOperands.push_back(
115 b.createOrFold<tensor::DimOp>(loc, padOp.getSource(), dim));
116 mapOperands.push_back(lowPad[dim]);
117 mapOperands.push_back(highPad[dim]);
118 AffineExpr expr = b.getAffineDimExpr(0) + b.getAffineSymbolExpr(0) +
119 b.getAffineSymbolExpr(1);
120 shapes.push_back(getValueOrCreateConstantIndexOp(
121 b, loc,
123 b, loc, AffineMap::get(1, 2, expr), mapOperands)));
124 }
125 reifiedReturnShapes.emplace_back(std::move(shapes));
126 return success();
127 }
128};
129
130} // namespace
131
133 DialectRegistry &registry) {
134 registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
135 ExpandShapeOp::attachInterface<ReifyExpandShapeOp>(*ctx);
136 CollapseShapeOp::attachInterface<ReifyCollapseShapeOp>(*ctx);
137 PadOp::attachInterface<ReifyPadOp>(*ctx);
138 });
139}
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static OpFoldResult getCollapsedOutputDimFromInputShape(OpBuilder &builder, Location loc, int64_t dimIndex, Value src, ArrayRef< int64_t > dstStaticShape, ArrayRef< AffineMap > reassociationMap)
For reshape op compute the shape at dimension dimIndex of the output in terms of shape of the src,...
static SmallVector< OpFoldResult, 4 > getCollapsedOutputShapeFromInputShape(OpBuilder &builder, Location loc, Value src, ArrayRef< int64_t > dstStaticShape, ArrayRef< AffineMap > reassociation)
Given the src of a collapsing reshape op and its reassociation maps, compute the shape of the result ...
Base type for affine expression.
Definition AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
ArrayRef< AffineExpr > getResults() const
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
AffineExpr getAffineSymbolExpr(unsigned position)
Definition Builders.cpp:372
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:209
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:528
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
void registerInferTypeOpInterfaceExternalModels(mlir::DialectRegistry &registry)
Registers external models for Infer Type interfaces for tensor ops.
Include the generated interface declarations.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:112
LogicalResult reifyResultShapes(Operation *op, OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) const