MLIR 22.0.0git
BuiltinAttributeInterfaces.cpp
Go to the documentation of this file.
1//===- BuiltinAttributeInterfaces.cpp -------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
11#include "mlir/IR/Diagnostics.h"
12#include "llvm/ADT/Sequence.h"
13
14using namespace mlir;
15using namespace mlir::detail;
16
17//===----------------------------------------------------------------------===//
18/// Tablegen Interface Definitions
19//===----------------------------------------------------------------------===//
20
21#include "mlir/IR/BuiltinAttributeInterfaces.cpp.inc"
22
23//===----------------------------------------------------------------------===//
24// ElementsAttr
25//===----------------------------------------------------------------------===//
26
27Type ElementsAttr::getElementType(ElementsAttr elementsAttr) {
28 return elementsAttr.getShapedType().getElementType();
29}
30
31int64_t ElementsAttr::getNumElements(ElementsAttr elementsAttr) {
32 return elementsAttr.getShapedType().getNumElements();
33}
34
35bool ElementsAttr::isValidIndex(ShapedType type, ArrayRef<uint64_t> index) {
36 // Verify that the rank of the indices matches the held type.
37 int64_t rank = type.getRank();
38 if (rank == 0 && index.size() == 1 && index[0] == 0)
39 return true;
40 if (rank != static_cast<int64_t>(index.size()))
41 return false;
42
43 // Verify that all of the indices are within the shape dimensions.
44 ArrayRef<int64_t> shape = type.getShape();
45 return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) {
46 int64_t dim = static_cast<int64_t>(index[i]);
47 return 0 <= dim && dim < shape[i];
48 });
49}
50bool ElementsAttr::isValidIndex(ElementsAttr elementsAttr,
52 return isValidIndex(elementsAttr.getShapedType(), index);
53}
54
55uint64_t ElementsAttr::getFlattenedIndex(Type type, ArrayRef<uint64_t> index) {
56 ShapedType shapeType = llvm::cast<ShapedType>(type);
57 assert(isValidIndex(shapeType, index) &&
58 "expected valid multi-dimensional index");
59
60 // Reduce the provided multidimensional index into a flattended 1D row-major
61 // index.
62 auto rank = shapeType.getRank();
63 ArrayRef<int64_t> shape = shapeType.getShape();
64 uint64_t valueIndex = 0;
65 uint64_t dimMultiplier = 1;
66 for (int i = rank - 1; i >= 0; --i) {
67 valueIndex += index[i] * dimMultiplier;
68 dimMultiplier *= shape[i];
69 }
70 return valueIndex;
71}
72
73//===----------------------------------------------------------------------===//
74// MemRefLayoutAttrInterface
75//===----------------------------------------------------------------------===//
76
80 if (m.getNumDims() != shape.size())
81 return emitError() << "memref layout mismatch between rank and affine map: "
82 << shape.size() << " != " << m.getNumDims();
83
84 return success();
85}
86
87// Fallback cases for terminal dim/sym/cst that are not part of a binary op (
88// i.e. single term). Accumulate the AffineExpr into the existing one.
90 AffineExpr multiplicativeFactor,
92 AffineExpr &offset) {
93 if (auto dim = dyn_cast<AffineDimExpr>(e))
94 strides[dim.getPosition()] =
95 strides[dim.getPosition()] + multiplicativeFactor;
96 else
97 offset = offset + e * multiplicativeFactor;
98}
99
100/// Takes a single AffineExpr `e` and populates the `strides` array with the
101/// strides expressions for each dim position.
102/// The convention is that the strides for dimensions d0, .. dn appear in
103/// order to make indexing intuitive into the result.
104static LogicalResult extractStrides(AffineExpr e,
105 AffineExpr multiplicativeFactor,
107 AffineExpr &offset) {
108 auto bin = dyn_cast<AffineBinaryOpExpr>(e);
109 if (!bin) {
110 extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
111 return success();
112 }
113
114 if (bin.getKind() == AffineExprKind::CeilDiv ||
115 bin.getKind() == AffineExprKind::FloorDiv ||
116 bin.getKind() == AffineExprKind::Mod)
117 return failure();
118
119 if (bin.getKind() == AffineExprKind::Mul) {
120 auto dim = dyn_cast<AffineDimExpr>(bin.getLHS());
121 if (dim) {
122 strides[dim.getPosition()] =
123 strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
124 return success();
125 }
126 // LHS and RHS may both contain complex expressions of dims. Try one path
127 // and if it fails try the other. This is guaranteed to succeed because
128 // only one path may have a `dim`, otherwise this is not an AffineExpr in
129 // the first place.
130 if (bin.getLHS().isSymbolicOrConstant())
131 return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
132 strides, offset);
133 return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
134 strides, offset);
135 }
136
137 if (bin.getKind() == AffineExprKind::Add) {
138 auto res1 =
139 extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
140 auto res2 =
141 extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
142 return success(succeeded(res1) && succeeded(res2));
143 }
144
145 llvm_unreachable("unexpected binary operation");
146}
147
148/// A stride specification is a list of integer values that are either static
149/// or dynamic (encoded with ShapedType::kDynamic). Strides encode
150/// the distance in the number of elements between successive entries along a
151/// particular dimension.
152///
153/// For example, `memref<42x16xf32, (64 * d0 + d1)>` specifies a view into a
154/// non-contiguous memory region of `42` by `16` `f32` elements in which the
155/// distance between two consecutive elements along the outer dimension is `1`
156/// and the distance between two consecutive elements along the inner dimension
157/// is `64`.
158///
159/// The convention is that the strides for dimensions d0, .. dn appear in
160/// order to make indexing intuitive into the result.
163 AffineExpr &offset) {
164 if (m.getNumResults() != 1 && !m.isIdentity())
165 return failure();
166
167 auto zero = getAffineConstantExpr(0, m.getContext());
168 auto one = getAffineConstantExpr(1, m.getContext());
169 offset = zero;
170 strides.assign(shape.size(), zero);
171
172 // Canonical case for empty map.
173 if (m.isIdentity()) {
174 // 0-D corner case, offset is already 0.
175 if (shape.empty())
176 return success();
177 auto stridedExpr = makeCanonicalStridedLayoutExpr(shape, m.getContext());
178 if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
179 return success();
180 assert(false && "unexpected failure: extract strides in canonical layout");
181 }
182
183 // Non-canonical case requires more work.
184 auto stridedExpr =
186 if (failed(extractStrides(stridedExpr, one, strides, offset))) {
187 offset = AffineExpr();
188 strides.clear();
189 return failure();
190 }
191
192 // Simplify results to allow folding to constants and simple checks.
193 unsigned numDims = m.getNumDims();
194 unsigned numSymbols = m.getNumSymbols();
195 offset = simplifyAffineExpr(offset, numDims, numSymbols);
196 for (auto &stride : strides)
197 stride = simplifyAffineExpr(stride, numDims, numSymbols);
198
199 return success();
200}
201
204 int64_t &offset) {
205 AffineExpr offsetExpr;
206 SmallVector<AffineExpr, 4> strideExprs;
207 if (failed(::getStridesAndOffset(map, shape, strideExprs, offsetExpr)))
208 return failure();
209 if (auto cst = llvm::dyn_cast<AffineConstantExpr>(offsetExpr))
210 offset = cst.getValue();
211 else
212 offset = ShapedType::kDynamic;
213 for (auto e : strideExprs) {
214 if (auto c = llvm::dyn_cast<AffineConstantExpr>(e))
215 strides.push_back(c.getValue());
216 else
217 strides.push_back(ShapedType::kDynamic);
218 }
219 return success();
220}
return success()
static LogicalResult getStridesAndOffset(AffineMap m, ArrayRef< int64_t > shape, SmallVectorImpl< AffineExpr > &strides, AffineExpr &offset)
A stride specification is a list of integer values that are either static or dynamic (encoded with Sh...
static LogicalResult extractStrides(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
Takes a single AffineExpr e and populates the strides array with the strides expressions for each dim...
static void extractStridesFromTerm(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
Base type for affine expression.
Definition AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
MLIRContext * getContext() const
unsigned getNumSymbols() const
unsigned getNumDims() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
bool isIdentity() const
Returns true if this affine map is an identity affine map.
This class represents a diagnostic that is inflight and set to be reported.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
AttrTypeReplacer.
LogicalResult getAffineMapStridesAndOffset(AffineMap map, ArrayRef< int64_t > shape, SmallVectorImpl< int64_t > &strides, int64_t &offset)
LogicalResult verifyAffineMapAsLayout(AffineMap m, ArrayRef< int64_t > shape, function_ref< InFlightDiagnostic()> emitError)
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
@ CeilDiv
RHS of ceildiv is always a constant or a symbolic expression.
Definition AffineExpr.h:50
@ Mul
RHS of mul is always a constant or a symbolic expression.
Definition AffineExpr.h:43
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
Definition AffineExpr.h:46
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
Definition AffineExpr.h:48
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef< int64_t > sizes, ArrayRef< AffineExpr > exprs, MLIRContext *context)
Given MemRef sizes that are either static or dynamic, returns the canonical "contiguous" strides Affi...
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152