12 #include "llvm/ADT/Sequence.h"
21 #include "mlir/IR/BuiltinAttributeInterfaces.cpp.inc"
28 return elementsAttr.getShapedType().getElementType();
32 return elementsAttr.getShapedType().getNumElements();
37 int64_t rank = type.getRank();
38 if (rank == 0 && index.size() == 1 && index[0] == 0)
40 if (rank !=
static_cast<int64_t
>(index.size()))
45 return llvm::all_of(llvm::seq<int>(0, rank), [&](
int i) {
46 int64_t dim =
static_cast<int64_t
>(index[i]);
47 return 0 <= dim && dim < shape[i];
50 bool ElementsAttr::isValidIndex(ElementsAttr elementsAttr,
52 return isValidIndex(elementsAttr.getShapedType(), index);
56 ShapedType shapeType = llvm::cast<ShapedType>(type);
57 assert(isValidIndex(shapeType, index) &&
58 "expected valid multi-dimensional index");
62 auto rank = shapeType.getRank();
64 uint64_t valueIndex = 0;
65 uint64_t dimMultiplier = 1;
66 for (
int i = rank - 1; i >= 0; --i) {
67 valueIndex += index[i] * dimMultiplier;
68 dimMultiplier *= shape[i];
80 if (m.getNumDims() != shape.size())
81 return emitError() <<
"memref layout mismatch between rank and affine map: "
82 << shape.size() <<
" != " << m.getNumDims();
93 if (
auto dim = dyn_cast<AffineDimExpr>(e))
94 strides[dim.getPosition()] =
95 strides[dim.getPosition()] + multiplicativeFactor;
97 offset = offset + e * multiplicativeFactor;
108 auto bin = dyn_cast<AffineBinaryOpExpr>(e);
120 auto dim = dyn_cast<AffineDimExpr>(bin.getLHS());
122 strides[dim.getPosition()] =
123 strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
130 if (bin.getLHS().isSymbolicOrConstant())
131 return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
133 return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
139 extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
141 extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
142 return success(succeeded(res1) && succeeded(res2));
145 llvm_unreachable(
"unexpected binary operation");
164 if (m.getNumResults() != 1 && !m.isIdentity())
170 strides.assign(shape.size(), zero);
173 if (m.isIdentity()) {
180 assert(
false &&
"unexpected failure: extract strides in canonical layout");
193 unsigned numDims = m.getNumDims();
194 unsigned numSymbols = m.getNumSymbols();
196 for (
auto &stride : strides)
209 if (
auto cst = llvm::dyn_cast<AffineConstantExpr>(offsetExpr))
210 offset = cst.getValue();
212 offset = ShapedType::kDynamic;
213 for (
auto e : strideExprs) {
214 if (
auto c = llvm::dyn_cast<AffineConstantExpr>(e))
215 strides.push_back(c.getValue());
217 strides.push_back(ShapedType::kDynamic);
static LogicalResult getStridesAndOffset(AffineMap m, ArrayRef< int64_t > shape, SmallVectorImpl< AffineExpr > &strides, AffineExpr &offset)
A stride specification is a list of integer values that are either static or dynamic (encoded with Sh...
static LogicalResult extractStrides(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
Takes a single AffineExpr e and populates the strides array with the strides expressions for each dim...
static void extractStridesFromTerm(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
This class represents a diagnostic that is inflight and set to be reported.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
LogicalResult getAffineMapStridesAndOffset(AffineMap map, ArrayRef< int64_t > shape, SmallVectorImpl< int64_t > &strides, int64_t &offset)
LogicalResult verifyAffineMapAsLayout(AffineMap m, ArrayRef< int64_t > shape, function_ref< InFlightDiagnostic()> emitError)
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
@ CeilDiv
RHS of ceildiv is always a constant or a symbolic expression.
@ Mul
RHS of mul is always a constant or a symbolic expression.
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef< int64_t > sizes, ArrayRef< AffineExpr > exprs, MLIRContext *context)
Given MemRef sizes that are either static or dynamic, returns the canonical "contiguous" strides Affi...
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.