MLIR  21.0.0git
BuiltinAttributeInterfaces.cpp
Go to the documentation of this file.
1 //===- BuiltinAttributeInterfaces.cpp -------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "mlir/IR/BuiltinTypes.h"
11 #include "mlir/IR/Diagnostics.h"
12 #include "llvm/ADT/Sequence.h"
13 
14 using namespace mlir;
15 using namespace mlir::detail;
16 
17 //===----------------------------------------------------------------------===//
18 /// Tablegen Interface Definitions
19 //===----------------------------------------------------------------------===//
20 
21 #include "mlir/IR/BuiltinAttributeInterfaces.cpp.inc"
22 
23 //===----------------------------------------------------------------------===//
24 // ElementsAttr
25 //===----------------------------------------------------------------------===//
26 
27 Type ElementsAttr::getElementType(ElementsAttr elementsAttr) {
28  return elementsAttr.getShapedType().getElementType();
29 }
30 
31 int64_t ElementsAttr::getNumElements(ElementsAttr elementsAttr) {
32  return elementsAttr.getShapedType().getNumElements();
33 }
34 
35 bool ElementsAttr::isValidIndex(ShapedType type, ArrayRef<uint64_t> index) {
36  // Verify that the rank of the indices matches the held type.
37  int64_t rank = type.getRank();
38  if (rank == 0 && index.size() == 1 && index[0] == 0)
39  return true;
40  if (rank != static_cast<int64_t>(index.size()))
41  return false;
42 
43  // Verify that all of the indices are within the shape dimensions.
44  ArrayRef<int64_t> shape = type.getShape();
45  return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) {
46  int64_t dim = static_cast<int64_t>(index[i]);
47  return 0 <= dim && dim < shape[i];
48  });
49 }
50 bool ElementsAttr::isValidIndex(ElementsAttr elementsAttr,
51  ArrayRef<uint64_t> index) {
52  return isValidIndex(elementsAttr.getShapedType(), index);
53 }
54 
55 uint64_t ElementsAttr::getFlattenedIndex(Type type, ArrayRef<uint64_t> index) {
56  ShapedType shapeType = llvm::cast<ShapedType>(type);
57  assert(isValidIndex(shapeType, index) &&
58  "expected valid multi-dimensional index");
59 
60  // Reduce the provided multidimensional index into a flattended 1D row-major
61  // index.
62  auto rank = shapeType.getRank();
63  ArrayRef<int64_t> shape = shapeType.getShape();
64  uint64_t valueIndex = 0;
65  uint64_t dimMultiplier = 1;
66  for (int i = rank - 1; i >= 0; --i) {
67  valueIndex += index[i] * dimMultiplier;
68  dimMultiplier *= shape[i];
69  }
70  return valueIndex;
71 }
72 
73 //===----------------------------------------------------------------------===//
74 // MemRefLayoutAttrInterface
75 //===----------------------------------------------------------------------===//
76 
80  if (m.getNumDims() != shape.size())
81  return emitError() << "memref layout mismatch between rank and affine map: "
82  << shape.size() << " != " << m.getNumDims();
83 
84  return success();
85 }
86 
87 // Fallback cases for terminal dim/sym/cst that are not part of a binary op (
88 // i.e. single term). Accumulate the AffineExpr into the existing one.
90  AffineExpr multiplicativeFactor,
92  AffineExpr &offset) {
93  if (auto dim = dyn_cast<AffineDimExpr>(e))
94  strides[dim.getPosition()] =
95  strides[dim.getPosition()] + multiplicativeFactor;
96  else
97  offset = offset + e * multiplicativeFactor;
98 }
99 
100 /// Takes a single AffineExpr `e` and populates the `strides` array with the
101 /// strides expressions for each dim position.
102 /// The convention is that the strides for dimensions d0, .. dn appear in
103 /// order to make indexing intuitive into the result.
104 static LogicalResult extractStrides(AffineExpr e,
105  AffineExpr multiplicativeFactor,
107  AffineExpr &offset) {
108  auto bin = dyn_cast<AffineBinaryOpExpr>(e);
109  if (!bin) {
110  extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
111  return success();
112  }
113 
114  if (bin.getKind() == AffineExprKind::CeilDiv ||
115  bin.getKind() == AffineExprKind::FloorDiv ||
116  bin.getKind() == AffineExprKind::Mod)
117  return failure();
118 
119  if (bin.getKind() == AffineExprKind::Mul) {
120  auto dim = dyn_cast<AffineDimExpr>(bin.getLHS());
121  if (dim) {
122  strides[dim.getPosition()] =
123  strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
124  return success();
125  }
126  // LHS and RHS may both contain complex expressions of dims. Try one path
127  // and if it fails try the other. This is guaranteed to succeed because
128  // only one path may have a `dim`, otherwise this is not an AffineExpr in
129  // the first place.
130  if (bin.getLHS().isSymbolicOrConstant())
131  return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
132  strides, offset);
133  return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
134  strides, offset);
135  }
136 
137  if (bin.getKind() == AffineExprKind::Add) {
138  auto res1 =
139  extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
140  auto res2 =
141  extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
142  return success(succeeded(res1) && succeeded(res2));
143  }
144 
145  llvm_unreachable("unexpected binary operation");
146 }
147 
148 /// A stride specification is a list of integer values that are either static
149 /// or dynamic (encoded with ShapedType::kDynamic). Strides encode
150 /// the distance in the number of elements between successive entries along a
151 /// particular dimension.
152 ///
153 /// For example, `memref<42x16xf32, (64 * d0 + d1)>` specifies a view into a
154 /// non-contiguous memory region of `42` by `16` `f32` elements in which the
155 /// distance between two consecutive elements along the outer dimension is `1`
156 /// and the distance between two consecutive elements along the inner dimension
157 /// is `64`.
158 ///
159 /// The convention is that the strides for dimensions d0, .. dn appear in
160 /// order to make indexing intuitive into the result.
161 static LogicalResult getStridesAndOffset(AffineMap m, ArrayRef<int64_t> shape,
163  AffineExpr &offset) {
164  if (m.getNumResults() != 1 && !m.isIdentity())
165  return failure();
166 
167  auto zero = getAffineConstantExpr(0, m.getContext());
168  auto one = getAffineConstantExpr(1, m.getContext());
169  offset = zero;
170  strides.assign(shape.size(), zero);
171 
172  // Canonical case for empty map.
173  if (m.isIdentity()) {
174  // 0-D corner case, offset is already 0.
175  if (shape.empty())
176  return success();
177  auto stridedExpr = makeCanonicalStridedLayoutExpr(shape, m.getContext());
178  if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
179  return success();
180  assert(false && "unexpected failure: extract strides in canonical layout");
181  }
182 
183  // Non-canonical case requires more work.
184  auto stridedExpr =
185  simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
186  if (failed(extractStrides(stridedExpr, one, strides, offset))) {
187  offset = AffineExpr();
188  strides.clear();
189  return failure();
190  }
191 
192  // Simplify results to allow folding to constants and simple checks.
193  unsigned numDims = m.getNumDims();
194  unsigned numSymbols = m.getNumSymbols();
195  offset = simplifyAffineExpr(offset, numDims, numSymbols);
196  for (auto &stride : strides)
197  stride = simplifyAffineExpr(stride, numDims, numSymbols);
198 
199  return success();
200 }
201 
204  int64_t &offset) {
205  AffineExpr offsetExpr;
206  SmallVector<AffineExpr, 4> strideExprs;
207  if (failed(::getStridesAndOffset(map, shape, strideExprs, offsetExpr)))
208  return failure();
209  if (auto cst = llvm::dyn_cast<AffineConstantExpr>(offsetExpr))
210  offset = cst.getValue();
211  else
212  offset = ShapedType::kDynamic;
213  for (auto e : strideExprs) {
214  if (auto c = llvm::dyn_cast<AffineConstantExpr>(e))
215  strides.push_back(c.getValue());
216  else
217  strides.push_back(ShapedType::kDynamic);
218  }
219  return success();
220 }
static LogicalResult getStridesAndOffset(AffineMap m, ArrayRef< int64_t > shape, SmallVectorImpl< AffineExpr > &strides, AffineExpr &offset)
A stride specification is a list of integer values that are either static or dynamic (encoded with ShapedType::kDynamic).
static LogicalResult extractStrides(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
Takes a single AffineExpr e and populates the strides array with the stride expressions for each dim position.
static void extractStridesFromTerm(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:188
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:46
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
AttrTypeReplacer.
LogicalResult getAffineMapStridesAndOffset(AffineMap map, ArrayRef< int64_t > shape, SmallVectorImpl< int64_t > &strides, int64_t &offset)
LogicalResult verifyAffineMapAsLayout(AffineMap m, ArrayRef< int64_t > shape, function_ref< InFlightDiagnostic()> emitError)
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
@ CeilDiv
RHS of ceildiv is always a constant or a symbolic expression.
@ Mul
RHS of mul is always a constant or a symbolic expression.
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:645
AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef< int64_t > sizes, ArrayRef< AffineExpr > exprs, MLIRContext *context)
Given MemRef sizes that are either static or dynamic, returns the canonical "contiguous" strides AffineExpr.
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.