MLIR 22.0.0git
ValueBoundsOpInterface.cpp
Go to the documentation of this file.
1//===- ValueBoundsOpInterface.cpp - Value Bounds -------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <utility>
10
12
14#include "mlir/IR/Matchers.h"
17#include "llvm/ADT/APSInt.h"
18#include "llvm/Support/Debug.h"
19#include "llvm/Support/DebugLog.h"
20
21#define DEBUG_TYPE "value-bounds-op-interface"
22
23using namespace mlir;
26
27namespace mlir {
28#include "mlir/Interfaces/ValueBoundsOpInterface.cpp.inc"
29} // namespace mlir
30
32 if (auto bbArg = dyn_cast<BlockArgument>(value))
33 return bbArg.getOwner()->getParentOp();
34 return value.getDefiningOp();
35}
36
40 : mixedOffsets(offsets), mixedSizes(sizes), mixedStrides(strides) {
41 assert(offsets.size() == sizes.size() &&
42 "expected same number of offsets, sizes, strides");
43 assert(offsets.size() == strides.size() &&
44 "expected same number of offsets, sizes, strides");
45}
46
49 : mixedOffsets(offsets), mixedSizes(sizes) {
50 assert(offsets.size() == sizes.size() &&
51 "expected same number of offsets and sizes");
52 // Assume that all strides are 1.
53 if (offsets.empty())
54 return;
55 MLIRContext *ctx = offsets.front().getContext();
56 mixedStrides.append(offsets.size(), Builder(ctx).getIndexAttr(1));
57}
58
62
63/// If ofr is a constant integer or an IntegerAttr, return the integer.
64static std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
65 // Case 1: Check for Constant integer.
66 if (auto val = llvm::dyn_cast_if_present<Value>(ofr)) {
67 APSInt intVal;
68 if (matchPattern(val, m_ConstantInt(&intVal)))
69 return intVal.getSExtValue();
70 return std::nullopt;
71 }
72 // Case 2: Check for IntegerAttr.
73 Attribute attr = llvm::dyn_cast_if_present<Attribute>(ofr);
74 if (auto intAttr = dyn_cast_or_null<IntegerAttr>(attr))
75 return intAttr.getValue().getSExtValue();
76 return std::nullopt;
77}
78
81
83 : Variable(static_cast<OpFoldResult>(indexValue)) {}
84
86 : Variable(static_cast<OpFoldResult>(shapedValue), std::optional(dim)) {}
87
89 std::optional<int64_t> dim) {
90 Builder b(ofr.getContext());
91 if (auto constInt = ::getConstantIntValue(ofr)) {
92 assert(!dim && "expected no dim for index-typed values");
93 map = AffineMap::get(/*dimCount=*/0, /*symbolCount=*/0,
94 b.getAffineConstantExpr(*constInt));
95 return;
96 }
97 Value value = cast<Value>(ofr);
98#ifndef NDEBUG
99 if (dim) {
100 assert(isa<ShapedType>(value.getType()) && "expected shaped type");
101 } else {
102 assert(value.getType().isIndex() && "expected index type");
103 }
104#endif // NDEBUG
105 map = AffineMap::get(/*dimCount=*/0, /*symbolCount=*/1,
106 b.getAffineSymbolExpr(0));
107 mapOperands.emplace_back(value, dim);
108}
109
111 ArrayRef<Variable> mapOperands) {
112 assert(map.getNumResults() == 1 && "expected single result");
113
114 // Turn all dims into symbols.
115 Builder b(map.getContext());
116 SmallVector<AffineExpr> dimReplacements, symReplacements;
117 for (int64_t i = 0, e = map.getNumDims(); i < e; ++i)
118 dimReplacements.push_back(b.getAffineSymbolExpr(i));
119 for (int64_t i = 0, e = map.getNumSymbols(); i < e; ++i)
120 symReplacements.push_back(b.getAffineSymbolExpr(i + map.getNumDims()));
121 AffineMap tmpMap = map.replaceDimsAndSymbols(
122 dimReplacements, symReplacements, /*numResultDims=*/0,
123 /*numResultSyms=*/map.getNumSymbols() + map.getNumDims());
124
125 // Inline operands.
127 for (auto [index, var] : llvm::enumerate(mapOperands)) {
128 assert(var.map.getNumResults() == 1 && "expected single result");
129 assert(var.map.getNumDims() == 0 && "expected only symbols");
130 SmallVector<AffineExpr> symReplacements;
131 for (auto valueDim : var.mapOperands) {
132 auto *it = llvm::find(this->mapOperands, valueDim);
133 if (it != this->mapOperands.end()) {
134 // There is already a symbol for this operand.
135 symReplacements.push_back(b.getAffineSymbolExpr(
136 std::distance(this->mapOperands.begin(), it)));
137 } else {
138 // This is a new operand: add a new symbol.
139 symReplacements.push_back(
140 b.getAffineSymbolExpr(this->mapOperands.size()));
141 this->mapOperands.push_back(valueDim);
142 }
143 }
144 replacements[b.getAffineSymbolExpr(index)] =
145 var.map.getResult(0).replaceSymbols(symReplacements);
146 }
147 this->map = tmpMap.replace(replacements, /*numResultDims=*/0,
148 /*numResultSyms=*/this->mapOperands.size());
149}
150
152 ValueRange mapOperands)
153 : Variable(map, llvm::map_to_vector(mapOperands,
154 [](Value v) { return Variable(v); })) {}
155
163
165
166#ifndef NDEBUG
167static void assertValidValueDim(Value value, std::optional<int64_t> dim) {
168 if (value.getType().isIndex()) {
169 assert(!dim.has_value() && "invalid dim value");
170 } else if (auto shapedType = dyn_cast<ShapedType>(value.getType())) {
171 assert(*dim >= 0 && "invalid dim value");
172 if (shapedType.hasRank())
173 assert(*dim < shapedType.getRank() && "invalid dim value");
174 } else {
175 llvm_unreachable("unsupported type");
176 }
177}
178#endif // NDEBUG
179
181 AffineExpr expr) {
182 // Note: If `addConservativeSemiAffineBounds` is true then the bound
183 // computation function needs to handle the case that the constraints set
184 // could become empty. This is because the conservative bounds add assumptions
185 // (e.g. for `mod` it assumes `rhs > 0`). If these constraints are later found
186 // not to hold, then the bound is invalid.
187 LogicalResult status = cstr.addBound(
188 type, pos,
189 AffineMap::get(cstr.getNumDimVars(), cstr.getNumSymbolVars(), expr),
193 if (failed(status)) {
194 // Some semi-affine expressions are not yet supported by
195 // FlatLinearConstraints. However, we can just ignore such failures here.
196 // Even without this bound, there may be enough information in the
197 // constraint system to compute the requested bound. In case this bound is
198 // actually needed, `computeBound` will return `failure`.
199 LDBG() << "Failed to add bound: " << expr << "\n";
200 }
201}
202
204 std::optional<int64_t> dim) {
205#ifndef NDEBUG
206 assertValidValueDim(value, dim);
207#endif // NDEBUG
208
209 // Check if the value/dim is statically known. In that case, an affine
210 // constant expression should be returned. This allows us to support
211 // multiplications with constants. (Multiplications of two columns in the
212 // constraint set is not supported.)
213 std::optional<int64_t> constSize = std::nullopt;
214 auto shapedType = dyn_cast<ShapedType>(value.getType());
215 if (shapedType) {
216 if (shapedType.hasRank() && !shapedType.isDynamicDim(*dim))
217 constSize = shapedType.getDimSize(*dim);
218 } else if (auto constInt = ::getConstantIntValue(value)) {
219 constSize = *constInt;
220 }
221
222 // If the value/dim is already mapped, return the corresponding expression
223 // directly.
224 ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue));
225 if (valueDimToPosition.contains(valueDim)) {
226 // If it is a constant, return an affine constant expression. Otherwise,
227 // return an affine expression that represents the respective column in the
228 // constraint set.
229 if (constSize)
230 return builder.getAffineConstantExpr(*constSize);
231 return getPosExpr(getPos(value, dim));
232 }
233
234 if (constSize) {
235 // Constant index value/dim: add column to the constraint set, add EQ bound
236 // and return an affine constant expression without pushing the newly added
237 // column to the worklist.
238 (void)insert(value, dim, /*isSymbol=*/true, /*addToWorklist=*/false);
239 if (shapedType)
240 bound(value)[*dim] == *constSize;
241 else
242 bound(value) == *constSize;
243 return builder.getAffineConstantExpr(*constSize);
244 }
245
246 // Dynamic value/dim: insert column to the constraint set and put it on the
247 // worklist. Return an affine expression that represents the newly inserted
248 // column in the constraint set.
249 return getPosExpr(insert(value, dim, /*isSymbol=*/true));
250}
251
253 if (Value value = llvm::dyn_cast_if_present<Value>(ofr))
254 return getExpr(value, /*dim=*/std::nullopt);
255 auto constInt = ::getConstantIntValue(ofr);
256 assert(constInt.has_value() && "expected Integer constant");
257 return builder.getAffineConstantExpr(*constInt);
258}
259
261 return builder.getAffineConstantExpr(constant);
262}
263
265 std::optional<int64_t> dim,
266 bool isSymbol, bool addToWorklist) {
267#ifndef NDEBUG
268 assertValidValueDim(value, dim);
269#endif // NDEBUG
270
271 ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue));
272 assert(!valueDimToPosition.contains(valueDim) && "already mapped");
273 int64_t pos = isSymbol ? cstr.appendVar(VarKind::Symbol)
274 : cstr.appendVar(VarKind::SetDim);
275 LDBG() << "Inserting constraint set column " << pos << " for: " << value
276 << " (dim: " << dim.value_or(kIndexValue)
277 << ", owner: " << getOwnerOfValue(value)->getName() << ")";
278 positionToValueDim.insert(positionToValueDim.begin() + pos, valueDim);
279 // Update reverse mapping.
280 for (int64_t i = pos, e = positionToValueDim.size(); i < e; ++i)
281 if (positionToValueDim[i].has_value())
283
284 if (addToWorklist) {
285 LDBG() << "Push to worklist: " << value
286 << " (dim: " << dim.value_or(kIndexValue) << ")";
287 worklist.push(pos);
288 }
289
290 return pos;
291}
292
294 int64_t pos = isSymbol ? cstr.appendVar(VarKind::Symbol)
295 : cstr.appendVar(VarKind::SetDim);
296 LDBG() << "Inserting anonymous constraint set column " << pos;
297 positionToValueDim.insert(positionToValueDim.begin() + pos, std::nullopt);
298 // Update reverse mapping.
299 for (int64_t i = pos, e = positionToValueDim.size(); i < e; ++i)
300 if (positionToValueDim[i].has_value())
302 return pos;
303}
304
306 const ValueDimList &operands,
307 bool isSymbol) {
308 assert(map.getNumResults() == 1 && "expected affine map with one result");
309 int64_t pos = insert(isSymbol);
310
311 // Add map and operands to the constraint set. Dimensions are converted to
312 // symbols. All operands are added to the worklist (unless they were already
313 // processed).
314 auto mapper = [&](std::pair<Value, std::optional<int64_t>> v) {
315 return getExpr(v.first, v.second);
316 };
317 SmallVector<AffineExpr> dimReplacements = llvm::to_vector(
318 llvm::map_range(ArrayRef(operands).take_front(map.getNumDims()), mapper));
319 SmallVector<AffineExpr> symReplacements = llvm::to_vector(
320 llvm::map_range(ArrayRef(operands).drop_front(map.getNumDims()), mapper));
321 addBound(
323 map.getResult(0).replaceDimsAndSymbols(dimReplacements, symReplacements));
324
325 return pos;
326}
327
329 return insert(var.map, var.mapOperands, isSymbol);
330}
331
333 std::optional<int64_t> dim) const {
334#ifndef NDEBUG
335 assertValidValueDim(value, dim);
336 assert((isa<OpResult>(value) ||
337 cast<BlockArgument>(value).getOwner()->isEntryBlock()) &&
338 "unstructured control flow is not supported");
339#endif // NDEBUG
340 LDBG() << "Getting pos for: " << value
341 << " (dim: " << dim.value_or(kIndexValue)
342 << ", owner: " << getOwnerOfValue(value)->getName() << ")";
343 auto it =
344 valueDimToPosition.find(std::make_pair(value, dim.value_or(kIndexValue)));
345 assert(it != valueDimToPosition.end() && "expected mapped entry");
346 return it->second;
347}
348
350 assert(pos >= 0 && pos < cstr.getNumDimAndSymbolVars() && "invalid position");
351 return pos < cstr.getNumDimVars()
352 ? builder.getAffineDimExpr(pos)
353 : builder.getAffineSymbolExpr(pos - cstr.getNumDimVars());
354}
355
357 std::optional<int64_t> dim) const {
358 auto it =
359 valueDimToPosition.find(std::make_pair(value, dim.value_or(kIndexValue)));
360 return it != valueDimToPosition.end();
361}
362
364 LDBG() << "Processing value bounds worklist...";
365 while (!worklist.empty()) {
366 int64_t pos = worklist.front();
367 worklist.pop();
368 assert(positionToValueDim[pos].has_value() &&
369 "did not expect std::nullopt on worklist");
370 ValueDim valueDim = *positionToValueDim[pos];
371 Value value = valueDim.first;
372 int64_t dim = valueDim.second;
373
374 // Check for static dim size.
375 if (dim != kIndexValue) {
376 auto shapedType = cast<ShapedType>(value.getType());
377 if (shapedType.hasRank() && !shapedType.isDynamicDim(dim)) {
378 bound(value)[dim] == getExpr(shapedType.getDimSize(dim));
379 continue;
380 }
381 }
382
383 // Do not process any further if the stop condition is met.
384 auto maybeDim = dim == kIndexValue ? std::nullopt : std::make_optional(dim);
385 if (stopCondition(value, maybeDim, *this)) {
386 LDBG() << "Stop condition met for: " << value << " (dim: " << maybeDim
387 << ")";
388 continue;
389 }
390
391 // Query `ValueBoundsOpInterface` for constraints. New items may be added to
392 // the worklist.
393 auto valueBoundsOp =
394 dyn_cast<ValueBoundsOpInterface>(getOwnerOfValue(value));
395 LDBG() << "Query value bounds for: " << value
396 << " (owner: " << getOwnerOfValue(value)->getName() << ")";
397 if (valueBoundsOp) {
398 if (dim == kIndexValue) {
399 valueBoundsOp.populateBoundsForIndexValue(value, *this);
400 } else {
401 valueBoundsOp.populateBoundsForShapedValueDim(value, dim, *this);
402 }
403 continue;
404 }
405 LDBG() << "--> ValueBoundsOpInterface not implemented";
406
407 // If the op does not implement `ValueBoundsOpInterface`, check if it
408 // implements the `DestinationStyleOpInterface`. OpResults of such ops are
409 // tied to OpOperands. Tied values have the same shape.
410 auto dstOp = value.getDefiningOp<DestinationStyleOpInterface>();
411 if (!dstOp || dim == kIndexValue)
412 continue;
413 Value tiedOperand = dstOp.getTiedOpOperand(cast<OpResult>(value))->get();
414 bound(value)[dim] == getExpr(tiedOperand, dim);
415 }
416}
417
419 assert(pos >= 0 && pos < static_cast<int64_t>(positionToValueDim.size()) &&
420 "invalid position");
421 cstr.projectOut(pos);
422 if (positionToValueDim[pos].has_value()) {
423 bool erased = valueDimToPosition.erase(*positionToValueDim[pos]);
424 (void)erased;
425 assert(erased && "inconsistent reverse mapping");
426 }
427 positionToValueDim.erase(positionToValueDim.begin() + pos);
428 // Update reverse mapping.
429 for (int64_t i = pos, e = positionToValueDim.size(); i < e; ++i)
430 if (positionToValueDim[i].has_value())
432}
433
435 function_ref<bool(ValueDim)> condition) {
436 int64_t nextPos = 0;
437 while (nextPos < static_cast<int64_t>(positionToValueDim.size())) {
438 if (positionToValueDim[nextPos].has_value() &&
439 condition(*positionToValueDim[nextPos])) {
440 projectOut(nextPos);
441 // The column was projected out so another column is now at that position.
442 // Do not increase the counter.
443 } else {
444 ++nextPos;
445 }
446 }
447}
448
450 std::optional<int64_t> except) {
451 int64_t nextPos = 0;
452 while (nextPos < static_cast<int64_t>(positionToValueDim.size())) {
453 if (positionToValueDim[nextPos].has_value() || except == nextPos) {
454 ++nextPos;
455 } else {
456 projectOut(nextPos);
457 // The column was projected out so another column is now at that position.
458 // Do not increase the counter.
459 }
460 }
461}
462
464 AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type,
465 const Variable &var, StopConditionFn stopCondition, bool closedUB) {
466 MLIRContext *ctx = var.getContext();
467 int64_t ubAdjustment = closedUB ? 0 : 1;
468 Builder b(ctx);
469 mapOperands.clear();
470
471 // Process the backward slice of `value` (i.e., reverse use-def chain) until
472 // `stopCondition` is met.
474 int64_t pos = cstr.insert(var, /*isSymbol=*/false);
475 assert(pos == 0 && "expected first column");
476 cstr.processWorklist();
477
478 // Project out all variables (apart from `valueDim`) that do not match the
479 // stop condition.
480 cstr.projectOut([&](ValueDim p) {
481 auto maybeDim =
482 p.second == kIndexValue ? std::nullopt : std::make_optional(p.second);
483 return !stopCondition(p.first, maybeDim, cstr);
484 });
485 cstr.projectOutAnonymous(/*except=*/pos);
486
487 // Compute lower and upper bounds for `valueDim`.
488 SmallVector<AffineMap> lb(1), ub(1);
489 cstr.cstr.getSliceBounds(pos, 1, ctx, &lb, &ub,
490 /*closedUB=*/true);
491
492 // Note: There are TODOs in the implementation of `getSliceBounds`. In such a
493 // case, no lower/upper bound can be computed at the moment.
494 // EQ, UB bounds: upper bound is needed.
495 if ((type != BoundType::LB) &&
496 (ub.empty() || !ub[0] || ub[0].getNumResults() == 0))
497 return failure();
498 // EQ, LB bounds: lower bound is needed.
499 if ((type != BoundType::UB) &&
500 (lb.empty() || !lb[0] || lb[0].getNumResults() == 0))
501 return failure();
502
503 // TODO: Generate an affine map with multiple results.
504 if (type != BoundType::LB)
505 assert(ub.size() == 1 && ub[0].getNumResults() == 1 &&
506 "multiple bounds not supported");
507 if (type != BoundType::UB)
508 assert(lb.size() == 1 && lb[0].getNumResults() == 1 &&
509 "multiple bounds not supported");
510
511 // EQ bound: lower and upper bound must match.
512 if (type == BoundType::EQ && ub[0] != lb[0])
513 return failure();
514
516 if (type == BoundType::EQ || type == BoundType::LB) {
517 bound = lb[0];
518 } else {
519 // Computed UB is a closed bound.
520 bound = AffineMap::get(ub[0].getNumDims(), ub[0].getNumSymbols(),
521 ub[0].getResult(0) + ubAdjustment);
522 }
523
524 // Gather all SSA values that are used in the computed bound.
525 assert(cstr.cstr.getNumDimAndSymbolVars() == cstr.positionToValueDim.size() &&
526 "inconsistent mapping state");
527 SmallVector<AffineExpr> replacementDims, replacementSymbols;
528 int64_t numDims = 0, numSymbols = 0;
529 for (int64_t i = 0; i < cstr.cstr.getNumDimAndSymbolVars(); ++i) {
530 // Skip `value`.
531 if (i == pos)
532 continue;
533 // Check if the position `i` is used in the generated bound. If so, it must
534 // be included in the generated affine.apply op.
535 bool used = false;
536 bool isDim = i < cstr.cstr.getNumDimVars();
537 if (isDim) {
538 if (bound.isFunctionOfDim(i))
539 used = true;
540 } else {
541 if (bound.isFunctionOfSymbol(i - cstr.cstr.getNumDimVars()))
542 used = true;
543 }
544
545 if (!used) {
546 // Not used: Remove dim/symbol from the result.
547 if (isDim) {
548 replacementDims.push_back(b.getAffineConstantExpr(0));
549 } else {
550 replacementSymbols.push_back(b.getAffineConstantExpr(0));
551 }
552 continue;
553 }
554
555 if (isDim) {
556 replacementDims.push_back(b.getAffineDimExpr(numDims++));
557 } else {
558 replacementSymbols.push_back(b.getAffineSymbolExpr(numSymbols++));
559 }
560
561 assert(cstr.positionToValueDim[i].has_value() &&
562 "cannot build affine map in terms of anonymous column");
563 ValueBoundsConstraintSet::ValueDim valueDim = *cstr.positionToValueDim[i];
564 Value value = valueDim.first;
565 int64_t dim = valueDim.second;
567 // An index-type value is used: can be used directly in the affine.apply
568 // op.
569 assert(value.getType().isIndex() && "expected index type");
570 mapOperands.push_back(std::make_pair(value, std::nullopt));
571 continue;
572 }
573
574 assert(cast<ShapedType>(value.getType()).isDynamicDim(dim) &&
575 "expected dynamic dim");
576 mapOperands.push_back(std::make_pair(value, dim));
577 }
578
579 resultMap = bound.replaceDimsAndSymbols(replacementDims, replacementSymbols,
580 numDims, numSymbols);
581 return success();
582}
583
585 AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type,
586 const Variable &var, ValueDimList dependencies, bool closedUB) {
587 return computeBound(
588 resultMap, mapOperands, type, var,
589 [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
590 return llvm::is_contained(dependencies, std::make_pair(v, d));
591 },
592 closedUB);
593}
594
596 AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type,
597 const Variable &var, ValueRange independencies, bool closedUB) {
598 // Return "true" if the given value is independent of all values in
599 // `independencies`. I.e., neither the value itself nor any value in the
600 // backward slice (reverse use-def chain) is contained in `independencies`.
601 auto isIndependent = [&](Value v) {
603 DenseSet<Value> visited;
604 worklist.push_back(v);
605 while (!worklist.empty()) {
606 Value next = worklist.pop_back_val();
607 if (!visited.insert(next).second)
608 continue;
609 if (llvm::is_contained(independencies, next))
610 return false;
611 // TODO: DominanceInfo could be used to stop the traversal early.
612 Operation *op = next.getDefiningOp();
613 if (!op)
614 continue;
615 worklist.append(op->getOperands().begin(), op->getOperands().end());
616 }
617 return true;
618 };
619
620 // Reify bounds in terms of any independent values.
621 return computeBound(
622 resultMap, mapOperands, type, var,
623 [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
624 return isIndependent(v);
625 },
626 closedUB);
627}
628
630 presburger::BoundType type, const Variable &var,
631 const StopConditionFn &stopCondition, bool closedUB) {
632 // Default stop condition if none was specified: Keep adding constraints until
633 // a bound could be computed.
634 int64_t pos = 0;
635 auto defaultStopCondition = [&](Value v, std::optional<int64_t> dim,
637 return cstr.cstr.getConstantBound64(type, pos).has_value();
638 };
639
641 var.getContext(), stopCondition ? stopCondition : defaultStopCondition);
642 pos = cstr.populateConstraints(var.map, var.mapOperands);
643 assert(pos == 0 && "expected `map` is the first column");
644
645 // Compute constant bound for `valueDim`.
646 int64_t ubAdjustment = closedUB ? 0 : 1;
647 if (auto bound = cstr.cstr.getConstantBound64(type, pos))
648 return type == BoundType::UB ? *bound + ubAdjustment : *bound;
649 return failure();
650}
651
653 std::optional<int64_t> dim) {
654#ifndef NDEBUG
655 assertValidValueDim(value, dim);
656#endif // NDEBUG
657
658 // `getExpr` pushes the value/dim onto the worklist (unless it was already
659 // analyzed).
660 (void)getExpr(value, dim);
661 // Process all values/dims on the worklist. This may traverse and analyze
662 // additional IR, depending on the current stop function.
664}
665
667 ValueDimList operands) {
668 int64_t pos = insert(map, std::move(operands), /*isSymbol=*/false);
669 // Process the backward slice of `operands` (i.e., reverse use-def chain)
670 // until `stopCondition` is met.
672 return pos;
673}
674
675FailureOr<int64_t>
677 std::optional<int64_t> dim1,
678 std::optional<int64_t> dim2) {
679#ifndef NDEBUG
680 assertValidValueDim(value1, dim1);
681 assertValidValueDim(value2, dim2);
682#endif // NDEBUG
683
684 Builder b(value1.getContext());
685 AffineMap map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0,
686 b.getAffineDimExpr(0) - b.getAffineDimExpr(1));
688 Variable(map, {{value1, dim1}, {value2, dim2}}));
689}
690
693 int64_t rhsPos) {
694 // This function returns "true" if "lhs CMP rhs" is proven to hold.
695 //
696 // Example for ComparisonOperator::LE and index-typed values: We would like to
697 // prove that lhs <= rhs. Proof by contradiction: add the inverse
698 // relation (lhs > rhs) to the constraint set and check if the resulting
699 // constraint set is "empty" (i.e. has no solution). In that case,
700 // lhs > rhs must be incorrect and we can deduce that lhs <= rhs holds.
701
702 // We cannot prove anything if the constraint set is already empty.
703 if (cstr.isEmpty()) {
704 LDBG() << "cannot compare value/dims: constraint system is already empty";
705 return false;
706 }
707
708 // EQ can be expressed as LE and GE.
709 if (cmp == EQ)
710 return comparePos(lhsPos, ComparisonOperator::LE, rhsPos) &&
711 comparePos(lhsPos, ComparisonOperator::GE, rhsPos);
712
713 // Construct inequality.
714 SmallVector<int64_t> eq(cstr.getNumCols(), 0);
715 if (cmp == LT || cmp == LE) {
716 ++eq[lhsPos];
717 --eq[rhsPos];
718 } else if (cmp == GT || cmp == GE) {
719 --eq[lhsPos];
720 ++eq[rhsPos];
721 } else {
722 llvm_unreachable("unsupported comparison operator");
723 }
724 if (cmp == LE || cmp == GE)
725 eq[cstr.getNumCols() - 1] -= 1;
726
727 // Add inequality to the constraint set and check if it made the constraint
728 // set empty.
729 int64_t ineqPos = cstr.getNumInequalities();
730 cstr.addInequality(eq);
731 bool isEmpty = cstr.isEmpty();
732 cstr.removeInequality(ineqPos);
733 return isEmpty;
734}
735
737 int64_t lhsPos, ComparisonOperator cmp, int64_t rhsPos) {
738 auto strongCmp = [&](ComparisonOperator cmp,
739 ComparisonOperator negCmp) -> FailureOr<bool> {
740 if (comparePos(lhsPos, cmp, rhsPos))
741 return true;
742 if (comparePos(lhsPos, negCmp, rhsPos))
743 return false;
744 return failure();
745 };
746 switch (cmp) {
756 std::optional<bool> le =
758 if (!le)
759 return failure();
760 if (!*le)
761 return false;
762 std::optional<bool> ge =
764 if (!ge)
765 return failure();
766 if (!*ge)
767 return false;
768 return true;
769 }
770 }
771 llvm_unreachable("invalid comparison operator");
772}
773
776 const Variable &rhs) {
777 int64_t lhsPos = populateConstraints(lhs.map, lhs.mapOperands);
778 int64_t rhsPos = populateConstraints(rhs.map, rhs.mapOperands);
779 return comparePos(lhsPos, cmp, rhsPos);
780}
781
784 const Variable &rhs) {
785 int64_t lhsPos = -1, rhsPos = -1;
786 auto stopCondition = [&](Value v, std::optional<int64_t> dim,
788 // Keep processing as long as lhs/rhs were not processed.
789 if (size_t(lhsPos) >= cstr.positionToValueDim.size() ||
790 size_t(rhsPos) >= cstr.positionToValueDim.size())
791 return false;
792 // Keep processing as long as the relation cannot be proven.
793 return cstr.comparePos(lhsPos, cmp, rhsPos);
794 };
796 lhsPos = cstr.populateConstraints(lhs.map, lhs.mapOperands);
797 rhsPos = cstr.populateConstraints(rhs.map, rhs.mapOperands);
798 return cstr.comparePos(lhsPos, cmp, rhsPos);
799}
800
803 const Variable &rhs) {
804 int64_t lhsPos = -1, rhsPos = -1;
805 auto stopCondition = [&](Value v, std::optional<int64_t> dim,
807 // Keep processing as long as lhs/rhs were not processed.
808 if (size_t(lhsPos) >= cstr.positionToValueDim.size() ||
809 size_t(rhsPos) >= cstr.positionToValueDim.size())
810 return false;
811 // Keep processing as long as the strong relation cannot be proven.
812 FailureOr<bool> ordered = cstr.strongComparePos(lhsPos, cmp, rhsPos);
813 return failed(ordered);
814 };
816 lhsPos = cstr.populateConstraints(lhs.map, lhs.mapOperands);
817 rhsPos = cstr.populateConstraints(rhs.map, rhs.mapOperands);
818 return cstr.strongComparePos(lhsPos, cmp, rhsPos);
819}
820
822 const Variable &var2) {
823 return strongCompare(var1, ComparisonOperator::EQ, var2);
824}
825
827 MLIRContext *ctx, const HyperrectangularSlice &slice1,
828 const HyperrectangularSlice &slice2) {
829 assert(slice1.getMixedOffsets().size() == slice2.getMixedOffsets().size() &&
830 "expected slices of same rank");
831 assert(slice1.getMixedSizes().size() == slice2.getMixedSizes().size() &&
832 "expected slices of same rank");
833 assert(slice1.getMixedStrides().size() == slice2.getMixedStrides().size() &&
834 "expected slices of same rank");
835
836 Builder b(ctx);
837 bool foundUnknownBound = false;
838 for (int64_t i = 0, e = slice1.getMixedOffsets().size(); i < e; ++i) {
839 AffineMap map =
840 AffineMap::get(/*dimCount=*/0, /*symbolCount=*/4,
841 b.getAffineSymbolExpr(0) +
842 b.getAffineSymbolExpr(1) * b.getAffineSymbolExpr(2) -
843 b.getAffineSymbolExpr(3));
844 {
845 // Case 1: Slices are guaranteed to be non-overlapping if
846 // offset1 + size1 * stride1 <= offset2 (for at least one dimension).
847 SmallVector<OpFoldResult> ofrOperands;
848 ofrOperands.push_back(slice1.getMixedOffsets()[i]);
849 ofrOperands.push_back(slice1.getMixedSizes()[i]);
850 ofrOperands.push_back(slice1.getMixedStrides()[i]);
851 ofrOperands.push_back(slice2.getMixedOffsets()[i]);
852 SmallVector<Value> valueOperands;
853 AffineMap foldedMap =
854 foldAttributesIntoMap(b, map, ofrOperands, valueOperands);
855 FailureOr<int64_t> constBound = computeConstantBound(
856 presburger::BoundType::EQ, Variable(foldedMap, valueOperands));
857 foundUnknownBound |= failed(constBound);
858 if (succeeded(constBound) && *constBound <= 0)
859 return false;
860 }
861 {
862 // Case 2: Slices are guaranteed to be non-overlapping if
863 // offset2 + size2 * stride2 <= offset1 (for at least one dimension).
864 SmallVector<OpFoldResult> ofrOperands;
865 ofrOperands.push_back(slice2.getMixedOffsets()[i]);
866 ofrOperands.push_back(slice2.getMixedSizes()[i]);
867 ofrOperands.push_back(slice2.getMixedStrides()[i]);
868 ofrOperands.push_back(slice1.getMixedOffsets()[i]);
869 SmallVector<Value> valueOperands;
870 AffineMap foldedMap =
871 foldAttributesIntoMap(b, map, ofrOperands, valueOperands);
872 FailureOr<int64_t> constBound = computeConstantBound(
873 presburger::BoundType::EQ, Variable(foldedMap, valueOperands));
874 foundUnknownBound |= failed(constBound);
875 if (succeeded(constBound) && *constBound <= 0)
876 return false;
877 }
878 }
879
880 // If at least one bound could not be computed, we cannot be certain that the
881 // slices are really overlapping.
882 if (foundUnknownBound)
883 return failure();
884
885 // All bounds could be computed and none of the above cases applied.
886 // Therefore, the slices are guaranteed to overlap.
887 return true;
888}
889
891 MLIRContext *ctx, const HyperrectangularSlice &slice1,
892 const HyperrectangularSlice &slice2) {
893 assert(slice1.getMixedOffsets().size() == slice2.getMixedOffsets().size() &&
894 "expected slices of same rank");
895 assert(slice1.getMixedSizes().size() == slice2.getMixedSizes().size() &&
896 "expected slices of same rank");
897 assert(slice1.getMixedStrides().size() == slice2.getMixedStrides().size() &&
898 "expected slices of same rank");
899
900 // The two slices are equivalent if all of their offsets, sizes and strides
901 // are equal. If equality cannot be determined for at least one of those
902 // values, equivalence cannot be determined and this function returns
903 // "failure".
904 for (auto [offset1, offset2] :
905 llvm::zip_equal(slice1.getMixedOffsets(), slice2.getMixedOffsets())) {
906 FailureOr<bool> equal = areEqual(offset1, offset2);
907 if (failed(equal))
908 return failure();
909 if (!equal.value())
910 return false;
911 }
912 for (auto [size1, size2] :
913 llvm::zip_equal(slice1.getMixedSizes(), slice2.getMixedSizes())) {
914 FailureOr<bool> equal = areEqual(size1, size2);
915 if (failed(equal))
916 return failure();
917 if (!equal.value())
918 return false;
919 }
920 for (auto [stride1, stride2] :
921 llvm::zip_equal(slice1.getMixedStrides(), slice2.getMixedStrides())) {
922 FailureOr<bool> equal = areEqual(stride1, stride2);
923 if (failed(equal))
924 return failure();
925 if (!equal.value())
926 return false;
927 }
928 return true;
929}
930
932 llvm::errs() << "==========\nColumns:\n";
933 llvm::errs() << "(column\tdim\tvalue)\n";
934 for (auto [index, valueDim] : llvm::enumerate(positionToValueDim)) {
935 llvm::errs() << " " << index << "\t";
936 if (valueDim) {
937 if (valueDim->second == kIndexValue) {
938 llvm::errs() << "n/a\t";
939 } else {
940 llvm::errs() << valueDim->second << "\t";
941 }
942 llvm::errs() << getOwnerOfValue(valueDim->first)->getName() << " ";
943 if (OpResult result = dyn_cast<OpResult>(valueDim->first)) {
944 llvm::errs() << "(result " << result.getResultNumber() << ")";
945 } else {
946 llvm::errs() << "(bbarg "
947 << cast<BlockArgument>(valueDim->first).getArgNumber()
948 << ")";
949 }
950 llvm::errs() << "\n";
951 } else {
952 llvm::errs() << "n/a\tn/a\n";
953 }
954 }
955 llvm::errs() << "\nConstraint set:\n";
956 cstr.dump();
957 llvm::errs() << "==========\n";
958}
959
962 assert(!this->dim.has_value() && "dim was already set");
963 this->dim = dim;
964#ifndef NDEBUG
965 assertValidValueDim(value, this->dim);
966#endif // NDEBUG
967 return *this;
968}
969
971#ifndef NDEBUG
972 assertValidValueDim(value, this->dim);
973#endif // NDEBUG
974 cstr.addBound(BoundType::UB, cstr.getPos(value, this->dim), expr);
975}
976
980
984
986#ifndef NDEBUG
987 assertValidValueDim(value, this->dim);
988#endif // NDEBUG
989 cstr.addBound(BoundType::LB, cstr.getPos(value, this->dim), expr);
990}
991
993#ifndef NDEBUG
994 assertValidValueDim(value, this->dim);
995#endif // NDEBUG
996 cstr.addBound(BoundType::EQ, cstr.getPos(value, this->dim), expr);
997}
998
1002
1006
1010
1014
1018
1022
1026
1030
1034
return success()
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static Operation * getOwnerOfValue(Value value)
static void assertValidValueDim(Value value, std::optional< int64_t > dim)
Base type for affine expression.
Definition AffineExpr.h:68
AffineExpr replaceDimsAndSymbols(ArrayRef< AffineExpr > dimReplacements, ArrayRef< AffineExpr > symReplacements) const
This method substitutes any uses of dimensions and symbols (e.g.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition AffineMap.h:46
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumDims() const
unsigned getNumResults() const
AffineMap replaceDimsAndSymbols(ArrayRef< AffineExpr > dimReplacements, ArrayRef< AffineExpr > symReplacements, unsigned numResultDims, unsigned numResultSyms) const
This method substitutes any uses of dimensions and symbols (e.g.
AffineExpr getResult(unsigned idx) const
AffineMap replace(AffineExpr expr, AffineExpr replacement, unsigned numResultDims, unsigned numResultSyms) const
Sparse replace method.
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
A hyperrectangular slice, represented as a list of offsets, sizes and strides.
HyperrectangularSlice(ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, ArrayRef< OpFoldResult > strides)
ArrayRef< OpFoldResult > getMixedStrides() const
ArrayRef< OpFoldResult > getMixedSizes() const
ArrayRef< OpFoldResult > getMixedOffsets() const
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class represents a single result from folding an operation.
MLIRContext * getContext() const
This is a value defined by a result of an operation.
Definition Value.h:457
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
Definition Operation.h:234
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
bool isIndex() const
Definition Types.cpp:54
Helper class that builds a bound for a shaped value dimension or index-typed value.
BoundBuilder & operator[](int64_t dim)
Specify a dimension, assuming that the underlying value is a shaped value.
A variable that can be added to the constraint set as a "column".
Variable(OpFoldResult ofr)
Construct a variable for an index-typed attribute or SSA value.
static bool compare(const Variable &lhs, ComparisonOperator cmp, const Variable &rhs)
Return "true" if "lhs cmp rhs" was proven to hold.
static FailureOr< bool > areEqual(const Variable &var1, const Variable &var2)
Compute whether the given variables are equal.
DenseMap< ValueDim, int64_t > valueDimToPosition
Reverse mapping of values/shape dimensions to columns.
void processWorklist()
Iteratively process all elements on the worklist until an index-typed value or shaped value meets the stop condition.
bool addConservativeSemiAffineBounds
Should conservative bounds be added for semi-affine expressions.
AffineExpr getExpr(Value value, std::optional< int64_t > dim=std::nullopt)
Return an expression that represents the given index-typed value or shaped value dimension.
SmallVector< std::optional< ValueDim > > positionToValueDim
Mapping of columns to values/shape dimensions.
static LogicalResult computeIndependentBound(AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type, const Variable &var, ValueRange independencies, bool closedUB=false)
Compute a bound that is independent of all values in independencies.
static FailureOr< bool > areEquivalentSlices(MLIRContext *ctx, const HyperrectangularSlice &slice1, const HyperrectangularSlice &slice2)
Return "true" if the given slices are guaranteed to be equivalent.
void projectOut(int64_t pos)
Project out the given column in the constraint set.
std::function< bool( Value, std::optional< int64_t >, ValueBoundsConstraintSet &cstr)> StopConditionFn
The stop condition used when traversing the backward slice of a shaped value or index-typed value.
ValueBoundsConstraintSet(MLIRContext *ctx, const StopConditionFn &stopCondition, bool addConservativeSemiAffineBounds=false)
static FailureOr< int64_t > computeConstantDelta(Value value1, Value value2, std::optional< int64_t > dim1=std::nullopt, std::optional< int64_t > dim2=std::nullopt)
Compute a constant delta between the given two values.
static llvm::FailureOr< bool > strongCompare(const Variable &lhs, ComparisonOperator cmp, const Variable &rhs)
This function is similar to ValueBoundsConstraintSet::compare, except that it returns false if !...
void addBound(presburger::BoundType type, int64_t pos, AffineExpr expr)
Bound the given column in the underlying constraint set by the given expression.
StopConditionFn stopCondition
The current stop condition function.
ComparisonOperator
Comparison operator for ValueBoundsConstraintSet::compare.
BoundBuilder bound(Value value)
Add a bound for the given index-typed value or shaped value.
static LogicalResult computeBound(AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type, const Variable &var, StopConditionFn stopCondition, bool closedUB=false)
Compute a bound for the given variable.
int64_t getPos(Value value, std::optional< int64_t > dim=std::nullopt) const
Return the column position of the given value/dimension.
int64_t insert(Value value, std::optional< int64_t > dim, bool isSymbol=true, bool addToWorklist=true)
Insert a value/dimension into the constraint set.
bool comparePos(int64_t lhsPos, ComparisonOperator cmp, int64_t rhsPos)
Return "true" if, based on the current state of the constraint system, "lhs cmp rhs" was proven to hold.
void dump() const
Debugging only: Dump the constraint set and the column-to-value/dim mapping to llvm::errs.
std::queue< int64_t > worklist
Worklist of values/shape dimensions that have not been processed yet.
FlatLinearConstraints cstr
Constraint system of equalities and inequalities.
bool isMapped(Value value, std::optional< int64_t > dim=std::nullopt) const
Return "true" if the given value/dim is mapped (i.e., has a corresponding column in the constraint system).
llvm::FailureOr< bool > strongComparePos(int64_t lhsPos, ComparisonOperator cmp, int64_t rhsPos)
Return "true" if, based on the current state of the constraint system, "lhs cmp rhs" was proven to hold.
AffineExpr getPosExpr(int64_t pos)
Return an affine expression that represents column pos in the constraint set.
void projectOutAnonymous(std::optional< int64_t > except=std::nullopt)
static FailureOr< bool > areOverlappingSlices(MLIRContext *ctx, const HyperrectangularSlice &slice1, const HyperrectangularSlice &slice2)
Return "true" if the given slices are guaranteed to be overlapping.
std::pair< Value, int64_t > ValueDim
An index-typed value or the dimension of a shaped-type value.
void populateConstraints(Value value, std::optional< int64_t > dim)
Traverse the IR starting from the given value/dim and populate constraints as long as the stop condit...
Builder builder
Builder for constructing affine expressions.
bool populateAndCompare(const Variable &lhs, ComparisonOperator cmp, const Variable &rhs)
Populate constraints for lhs/rhs (until the stop condition is met).
static FailureOr< int64_t > computeConstantBound(presburger::BoundType type, const Variable &var, const StopConditionFn &stopCondition=nullptr, bool closedUB=false)
Compute a constant bound for the given variable.
static constexpr int64_t kIndexValue
Dimension identifier to indicate a value is index-typed.
static LogicalResult computeDependentBound(AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type, const Variable &var, ValueDimList dependencies, bool closedUB=false)
Compute a bound in terms of the values/dimensions in dependencies.
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition Value.h:108
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition CallGraph.h:229
BoundType
The type of bound: equal, lower bound or upper bound.
VarKind
Kind of variable.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bind_value.
Definition Matchers.h:527
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:128
bool operator==(StringAttr lhs, std::nullptr_t)
Define comparisons for StringAttr against nullptr and itself to avoid the StringRef overloads from be...
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
SmallVector< std::pair< Value, std::optional< int64_t > > > ValueDimList
AffineMap foldAttributesIntoMap(Builder &b, AffineMap map, ArrayRef< OpFoldResult > operands, SmallVector< Value > &remainingValues)
Fold all attributes among the given operands into the affine map.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152