MLIR 22.0.0git
DimLvlMap.cpp
Go to the documentation of this file.
1//===- DimLvlMap.cpp ------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "DimLvlMap.h"
10
11using namespace mlir;
12using namespace mlir::sparse_tensor;
13using namespace mlir::sparse_tensor::ir_detail;
14
15//===----------------------------------------------------------------------===//
16// `DimLvlExpr` implementation.
17//===----------------------------------------------------------------------===//
18
20 return SymVar(llvm::cast<AffineSymbolExpr>(expr));
21}
22
23std::optional<SymVar> DimLvlExpr::dyn_castSymVar() const {
24 if (const auto s = dyn_cast_or_null<AffineSymbolExpr>(expr))
25 return SymVar(s);
26 return std::nullopt;
27}
28
30 return Var(getAllowedVarKind(), llvm::cast<AffineDimExpr>(expr));
31}
32
33std::optional<Var> DimLvlExpr::dyn_castDimLvlVar() const {
34 if (const auto x = dyn_cast_or_null<AffineDimExpr>(expr))
35 return Var(getAllowedVarKind(), x);
36 return std::nullopt;
37}
38
39std::tuple<DimLvlExpr, AffineExprKind, DimLvlExpr>
41 const auto ak = getAffineKind();
42 const auto binop = llvm::dyn_cast<AffineBinaryOpExpr>(expr);
43 const DimLvlExpr lhs(kind, binop ? binop.getLHS() : nullptr);
44 const DimLvlExpr rhs(kind, binop ? binop.getRHS() : nullptr);
45 return {lhs, ak, rhs};
46}
47
48//===----------------------------------------------------------------------===//
49// `DimSpec` implementation.
50//===----------------------------------------------------------------------===//
51
52DimSpec::DimSpec(DimVar var, DimExpr expr, SparseTensorDimSliceAttr slice)
53 : var(var), expr(expr), slice(slice) {}
54
55bool DimSpec::isValid(Ranks const &ranks) const {
56 // Nothing in `slice` needs additional validation.
57 // We explicitly consider null-expr to be vacuously valid.
58 return ranks.isValid(var) && (!expr || ranks.isValid(expr));
59}
60
61//===----------------------------------------------------------------------===//
62// `LvlSpec` implementation.
63//===----------------------------------------------------------------------===//
64
66 : var(var), expr(expr), type(type) {
67 assert(expr);
68 assert(isValidLT(type) && !isUndefLT(type));
69}
70
71bool LvlSpec::isValid(Ranks const &ranks) const {
72 // Nothing in `type` needs additional validation.
73 return ranks.isValid(var) && ranks.isValid(expr);
74}
75
76//===----------------------------------------------------------------------===//
77// `DimLvlMap` implementation.
78//===----------------------------------------------------------------------===//
79
80DimLvlMap::DimLvlMap(unsigned symRank, ArrayRef<DimSpec> dimSpecs,
81 ArrayRef<LvlSpec> lvlSpecs)
82 : symRank(symRank), dimSpecs(dimSpecs), lvlSpecs(lvlSpecs),
83 mustPrintLvlVars(false) {
84 // First, check integrity of the variable-binding structure.
85 // NOTE: This establishes the invariant that calls to `VarSet::add`
86 // below cannot cause OOB errors.
87 assert(isWF());
88
89 VarSet usedVars(getRanks());
90 for (const auto &dimSpec : dimSpecs)
91 if (!dimSpec.canElideExpr())
92 usedVars.add(dimSpec.getExpr());
93 for (auto &lvlSpec : this->lvlSpecs) {
94 // Is this LvlVar used in any overt expression?
95 const bool isUsed = usedVars.contains(lvlSpec.getBoundVar());
96 // This LvlVar can be elided iff it isn't overtly used.
97 lvlSpec.setElideVar(!isUsed);
98 // If any LvlVar cannot be elided, then must forward-declare all LvlVars.
99 mustPrintLvlVars = mustPrintLvlVars || isUsed;
100 }
101}
102
103bool DimLvlMap::isWF() const {
104 const auto ranks = getRanks();
105 unsigned dimNum = 0;
106 for (const auto &dimSpec : dimSpecs)
107 if (dimSpec.getBoundVar().getNum() != dimNum++ || !dimSpec.isValid(ranks))
108 return false;
109 assert(dimNum == ranks.getDimRank());
110 unsigned lvlNum = 0;
111 for (const auto &lvlSpec : lvlSpecs)
112 if (lvlSpec.getBoundVar().getNum() != lvlNum++ || !lvlSpec.isValid(ranks))
113 return false;
114 assert(lvlNum == ranks.getLvlRank());
115 return true;
116}
117
119 SmallVector<AffineExpr> lvlAffines;
120 lvlAffines.reserve(getLvlRank());
121 for (const auto &lvlSpec : lvlSpecs)
122 lvlAffines.push_back(lvlSpec.getExpr().getAffineExpr());
123 auto map = AffineMap::get(getDimRank(), getSymRank(), lvlAffines, context);
124 return map;
125}
126
128 SmallVector<AffineExpr> dimAffines;
129 dimAffines.reserve(getDimRank());
130 for (const auto &dimSpec : dimSpecs) {
131 auto expr = dimSpec.getExpr().getAffineExpr();
132 if (expr) {
133 dimAffines.push_back(expr);
134 }
135 }
136 auto map = AffineMap::get(getLvlRank(), getSymRank(), dimAffines, context);
137 // If no lvlToDim map was passed in, returns a null AffineMap and infers it
138 // in SparseTensorEncodingAttr::parse.
139 if (dimAffines.empty())
140 return AffineMap();
141 return map;
142}
143
144//===----------------------------------------------------------------------===//
lhs
false
Parses a map_entries map type from a string format back into its numeric value.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition AffineMap.h:46
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
std::optional< SymVar > dyn_castSymVar() const
Definition DimLvlMap.cpp:23
std::tuple< DimLvlExpr, AffineExprKind, DimLvlExpr > unpackBinop() const
Definition DimLvlMap.cpp:40
constexpr DimLvlExpr(ExprKind ek, AffineExpr expr)
Definition DimLvlMap.h:38
constexpr VarKind getAllowedVarKind() const
Definition DimLvlMap.h:65
std::optional< Var > dyn_castDimLvlVar() const
Definition DimLvlMap.cpp:33
DimLvlMap(unsigned symRank, ArrayRef< DimSpec > dimSpecs, ArrayRef< LvlSpec > lvlSpecs)
Definition DimLvlMap.cpp:80
AffineMap getDimToLvlMap(MLIRContext *context) const
AffineMap getLvlToDimMap(MLIRContext *context) const
bool isValid(Ranks const &ranks) const
Checks whether the variables bound/used by this spec are valid with respect to the given ranks.
Definition DimLvlMap.cpp:55
DimSpec(DimVar var, DimExpr expr, SparseTensorDimSliceAttr slice)
Definition DimLvlMap.cpp:52
bool isValid(Ranks const &ranks) const
Checks whether the variables bound/used by this spec are valid with respect to the given ranks.
Definition DimLvlMap.cpp:71
LvlSpec(LvlVar var, LvlExpr expr, LevelType type)
Definition DimLvlMap.cpp:65
constexpr bool isValid(Var var) const
Definition Var.h:233
Efficient representation of a set of Var.
Definition Var.h:242
bool contains(Var var) const
For the contains method: if variables occurring in the method parameter are OOB for the VarSet,...
Definition Var.cpp:77
void add(Var var)
For the add methods: OOB parameters cause undefined behavior.
Definition Var.cpp:87
A concrete variable, to be used in our variant of AffineExpr.
Definition Var.h:61
bool isUndefLT(LevelType lt)
Definition Enums.h:412
bool isValidLT(LevelType lt)
Definition Enums.h:433
Include the generated interface declarations.
This enum defines all the sparse representations supportable by the SparseTensor dialect.
Definition Enums.h:238