MLIR 22.0.0git
ValueBoundsOpInterface.cpp
Go to the documentation of this file.
1//===- ValueBoundsOpInterface.cpp - Value Bounds -------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <utility>
10
12
14#include "mlir/IR/Matchers.h"
17#include "llvm/ADT/APSInt.h"
18#include "llvm/Support/Debug.h"
19
20#define DEBUG_TYPE "value-bounds-op-interface"
21
22using namespace mlir;
25
26namespace mlir {
27#include "mlir/Interfaces/ValueBoundsOpInterface.cpp.inc"
28} // namespace mlir
29
31 if (auto bbArg = dyn_cast<BlockArgument>(value))
32 return bbArg.getOwner()->getParentOp();
33 return value.getDefiningOp();
34}
35
39 : mixedOffsets(offsets), mixedSizes(sizes), mixedStrides(strides) {
40 assert(offsets.size() == sizes.size() &&
41 "expected same number of offsets, sizes, strides");
42 assert(offsets.size() == strides.size() &&
43 "expected same number of offsets, sizes, strides");
44}
45
48 : mixedOffsets(offsets), mixedSizes(sizes) {
49 assert(offsets.size() == sizes.size() &&
50 "expected same number of offsets and sizes");
51 // Assume that all strides are 1.
52 if (offsets.empty())
53 return;
54 MLIRContext *ctx = offsets.front().getContext();
55 mixedStrides.append(offsets.size(), Builder(ctx).getIndexAttr(1));
56}
57
61
62/// If ofr is a constant integer or an IntegerAttr, return the integer.
63static std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
64 // Case 1: Check for Constant integer.
65 if (auto val = llvm::dyn_cast_if_present<Value>(ofr)) {
66 APSInt intVal;
67 if (matchPattern(val, m_ConstantInt(&intVal)))
68 return intVal.getSExtValue();
69 return std::nullopt;
70 }
71 // Case 2: Check for IntegerAttr.
72 Attribute attr = llvm::dyn_cast_if_present<Attribute>(ofr);
73 if (auto intAttr = dyn_cast_or_null<IntegerAttr>(attr))
74 return intAttr.getValue().getSExtValue();
75 return std::nullopt;
76}
77
80
82 : Variable(static_cast<OpFoldResult>(indexValue)) {}
83
85 : Variable(static_cast<OpFoldResult>(shapedValue), std::optional(dim)) {}
86
88 std::optional<int64_t> dim) {
89 Builder b(ofr.getContext());
90 if (auto constInt = ::getConstantIntValue(ofr)) {
91 assert(!dim && "expected no dim for index-typed values");
92 map = AffineMap::get(/*dimCount=*/0, /*symbolCount=*/0,
93 b.getAffineConstantExpr(*constInt));
94 return;
95 }
96 Value value = cast<Value>(ofr);
97#ifndef NDEBUG
98 if (dim) {
99 assert(isa<ShapedType>(value.getType()) && "expected shaped type");
100 } else {
101 assert(value.getType().isIndex() && "expected index type");
102 }
103#endif // NDEBUG
104 map = AffineMap::get(/*dimCount=*/0, /*symbolCount=*/1,
105 b.getAffineSymbolExpr(0));
106 mapOperands.emplace_back(value, dim);
107}
108
110 ArrayRef<Variable> mapOperands) {
111 assert(map.getNumResults() == 1 && "expected single result");
112
113 // Turn all dims into symbols.
114 Builder b(map.getContext());
115 SmallVector<AffineExpr> dimReplacements, symReplacements;
116 for (int64_t i = 0, e = map.getNumDims(); i < e; ++i)
117 dimReplacements.push_back(b.getAffineSymbolExpr(i));
118 for (int64_t i = 0, e = map.getNumSymbols(); i < e; ++i)
119 symReplacements.push_back(b.getAffineSymbolExpr(i + map.getNumDims()));
120 AffineMap tmpMap = map.replaceDimsAndSymbols(
121 dimReplacements, symReplacements, /*numResultDims=*/0,
122 /*numResultSyms=*/map.getNumSymbols() + map.getNumDims());
123
124 // Inline operands.
126 for (auto [index, var] : llvm::enumerate(mapOperands)) {
127 assert(var.map.getNumResults() == 1 && "expected single result");
128 assert(var.map.getNumDims() == 0 && "expected only symbols");
129 SmallVector<AffineExpr> symReplacements;
130 for (auto valueDim : var.mapOperands) {
131 auto it = llvm::find(this->mapOperands, valueDim);
132 if (it != this->mapOperands.end()) {
133 // There is already a symbol for this operand.
134 symReplacements.push_back(b.getAffineSymbolExpr(
135 std::distance(this->mapOperands.begin(), it)));
136 } else {
137 // This is a new operand: add a new symbol.
138 symReplacements.push_back(
139 b.getAffineSymbolExpr(this->mapOperands.size()));
140 this->mapOperands.push_back(valueDim);
141 }
142 }
143 replacements[b.getAffineSymbolExpr(index)] =
144 var.map.getResult(0).replaceSymbols(symReplacements);
145 }
146 this->map = tmpMap.replace(replacements, /*numResultDims=*/0,
147 /*numResultSyms=*/this->mapOperands.size());
148}
149
151 ValueRange mapOperands)
152 : Variable(map, llvm::map_to_vector(mapOperands,
153 [](Value v) { return Variable(v); })) {}
154
162
164
165#ifndef NDEBUG
166static void assertValidValueDim(Value value, std::optional<int64_t> dim) {
167 if (value.getType().isIndex()) {
168 assert(!dim.has_value() && "invalid dim value");
169 } else if (auto shapedType = dyn_cast<ShapedType>(value.getType())) {
170 assert(*dim >= 0 && "invalid dim value");
171 if (shapedType.hasRank())
172 assert(*dim < shapedType.getRank() && "invalid dim value");
173 } else {
174 llvm_unreachable("unsupported type");
175 }
176}
177#endif // NDEBUG
178
180 AffineExpr expr) {
181 // Note: If `addConservativeSemiAffineBounds` is true then the bound
182 // computation function needs to handle the case that the constraints set
183 // could become empty. This is because the conservative bounds add assumptions
184 // (e.g. for `mod` it assumes `rhs > 0`). If these constraints are later found
185 // not to hold, then the bound is invalid.
186 LogicalResult status = cstr.addBound(
187 type, pos,
188 AffineMap::get(cstr.getNumDimVars(), cstr.getNumSymbolVars(), expr),
192 if (failed(status)) {
193 // Not all semi-affine expressions are not yet supported by
194 // FlatLinearConstraints. However, we can just ignore such failures here.
195 // Even without this bound, there may be enough information in the
196 // constraint system to compute the requested bound. In case this bound is
197 // actually needed, `computeBound` will return `failure`.
198 LLVM_DEBUG(llvm::dbgs() << "Failed to add bound: " << expr << "\n");
199 }
200}
201
203 std::optional<int64_t> dim) {
204#ifndef NDEBUG
205 assertValidValueDim(value, dim);
206#endif // NDEBUG
207
208 // Check if the value/dim is statically known. In that case, an affine
209 // constant expression should be returned. This allows us to support
210 // multiplications with constants. (Multiplications of two columns in the
211 // constraint set is not supported.)
212 std::optional<int64_t> constSize = std::nullopt;
213 auto shapedType = dyn_cast<ShapedType>(value.getType());
214 if (shapedType) {
215 if (shapedType.hasRank() && !shapedType.isDynamicDim(*dim))
216 constSize = shapedType.getDimSize(*dim);
217 } else if (auto constInt = ::getConstantIntValue(value)) {
218 constSize = *constInt;
219 }
220
221 // If the value/dim is already mapped, return the corresponding expression
222 // directly.
223 ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue));
224 if (valueDimToPosition.contains(valueDim)) {
225 // If it is a constant, return an affine constant expression. Otherwise,
226 // return an affine expression that represents the respective column in the
227 // constraint set.
228 if (constSize)
229 return builder.getAffineConstantExpr(*constSize);
230 return getPosExpr(getPos(value, dim));
231 }
232
233 if (constSize) {
234 // Constant index value/dim: add column to the constraint set, add EQ bound
235 // and return an affine constant expression without pushing the newly added
236 // column to the worklist.
237 (void)insert(value, dim, /*isSymbol=*/true, /*addToWorklist=*/false);
238 if (shapedType)
239 bound(value)[*dim] == *constSize;
240 else
241 bound(value) == *constSize;
242 return builder.getAffineConstantExpr(*constSize);
243 }
244
245 // Dynamic value/dim: insert column to the constraint set and put it on the
246 // worklist. Return an affine expression that represents the newly inserted
247 // column in the constraint set.
248 return getPosExpr(insert(value, dim, /*isSymbol=*/true));
249}
250
252 if (Value value = llvm::dyn_cast_if_present<Value>(ofr))
253 return getExpr(value, /*dim=*/std::nullopt);
254 auto constInt = ::getConstantIntValue(ofr);
255 assert(constInt.has_value() && "expected Integer constant");
256 return builder.getAffineConstantExpr(*constInt);
257}
258
260 return builder.getAffineConstantExpr(constant);
261}
262
264 std::optional<int64_t> dim,
265 bool isSymbol, bool addToWorklist) {
266#ifndef NDEBUG
267 assertValidValueDim(value, dim);
268#endif // NDEBUG
269
270 ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue));
271 assert(!valueDimToPosition.contains(valueDim) && "already mapped");
272 int64_t pos = isSymbol ? cstr.appendVar(VarKind::Symbol)
273 : cstr.appendVar(VarKind::SetDim);
274 LLVM_DEBUG(llvm::dbgs() << "Inserting constraint set column " << pos
275 << " for: " << value
276 << " (dim: " << dim.value_or(kIndexValue)
277 << ", owner: " << getOwnerOfValue(value)->getName()
278 << ")\n");
279 positionToValueDim.insert(positionToValueDim.begin() + pos, valueDim);
280 // Update reverse mapping.
281 for (int64_t i = pos, e = positionToValueDim.size(); i < e; ++i)
282 if (positionToValueDim[i].has_value())
284
285 if (addToWorklist) {
286 LLVM_DEBUG(llvm::dbgs() << "Push to worklist: " << value
287 << " (dim: " << dim.value_or(kIndexValue) << ")\n");
288 worklist.push(pos);
289 }
290
291 return pos;
292}
293
295 int64_t pos = isSymbol ? cstr.appendVar(VarKind::Symbol)
296 : cstr.appendVar(VarKind::SetDim);
297 LLVM_DEBUG(llvm::dbgs() << "Inserting anonymous constraint set column " << pos
298 << "\n");
299 positionToValueDim.insert(positionToValueDim.begin() + pos, std::nullopt);
300 // Update reverse mapping.
301 for (int64_t i = pos, e = positionToValueDim.size(); i < e; ++i)
302 if (positionToValueDim[i].has_value())
304 return pos;
305}
306
308 const ValueDimList &operands,
309 bool isSymbol) {
310 assert(map.getNumResults() == 1 && "expected affine map with one result");
311 int64_t pos = insert(isSymbol);
312
313 // Add map and operands to the constraint set. Dimensions are converted to
314 // symbols. All operands are added to the worklist (unless they were already
315 // processed).
316 auto mapper = [&](std::pair<Value, std::optional<int64_t>> v) {
317 return getExpr(v.first, v.second);
318 };
319 SmallVector<AffineExpr> dimReplacements = llvm::to_vector(
320 llvm::map_range(ArrayRef(operands).take_front(map.getNumDims()), mapper));
321 SmallVector<AffineExpr> symReplacements = llvm::to_vector(
322 llvm::map_range(ArrayRef(operands).drop_front(map.getNumDims()), mapper));
323 addBound(
325 map.getResult(0).replaceDimsAndSymbols(dimReplacements, symReplacements));
326
327 return pos;
328}
331 return insert(var.map, var.mapOperands, isSymbol);
332}
333
335 std::optional<int64_t> dim) const {
336#ifndef NDEBUG
337 assertValidValueDim(value, dim);
338 assert((isa<OpResult>(value) ||
339 cast<BlockArgument>(value).getOwner()->isEntryBlock()) &&
340 "unstructured control flow is not supported");
341#endif // NDEBUG
342 LLVM_DEBUG(llvm::dbgs() << "Getting pos for: " << value
343 << " (dim: " << dim.value_or(kIndexValue)
344 << ", owner: " << getOwnerOfValue(value)->getName()
345 << ")\n");
346 auto it =
347 valueDimToPosition.find(std::make_pair(value, dim.value_or(kIndexValue)));
348 assert(it != valueDimToPosition.end() && "expected mapped entry");
349 return it->second;
350}
351
353 assert(pos >= 0 && pos < cstr.getNumDimAndSymbolVars() && "invalid position");
354 return pos < cstr.getNumDimVars()
355 ? builder.getAffineDimExpr(pos)
356 : builder.getAffineSymbolExpr(pos - cstr.getNumDimVars());
357}
358
360 std::optional<int64_t> dim) const {
361 auto it =
362 valueDimToPosition.find(std::make_pair(value, dim.value_or(kIndexValue)));
363 return it != valueDimToPosition.end();
364}
365
367 LLVM_DEBUG(llvm::dbgs() << "Processing value bounds worklist...\n");
368 while (!worklist.empty()) {
369 int64_t pos = worklist.front();
370 worklist.pop();
371 assert(positionToValueDim[pos].has_value() &&
372 "did not expect std::nullopt on worklist");
373 ValueDim valueDim = *positionToValueDim[pos];
374 Value value = valueDim.first;
375 int64_t dim = valueDim.second;
376
377 // Check for static dim size.
378 if (dim != kIndexValue) {
379 auto shapedType = cast<ShapedType>(value.getType());
380 if (shapedType.hasRank() && !shapedType.isDynamicDim(dim)) {
381 bound(value)[dim] == getExpr(shapedType.getDimSize(dim));
382 continue;
383 }
384 }
385
386 // Do not process any further if the stop condition is met.
387 auto maybeDim = dim == kIndexValue ? std::nullopt : std::make_optional(dim);
388 if (stopCondition(value, maybeDim, *this)) {
389 LLVM_DEBUG(llvm::dbgs() << "Stop condition met for: " << value
390 << " (dim: " << maybeDim << ")\n");
391 continue;
392 }
393
394 // Query `ValueBoundsOpInterface` for constraints. New items may be added to
395 // the worklist.
396 auto valueBoundsOp =
397 dyn_cast<ValueBoundsOpInterface>(getOwnerOfValue(value));
398 LLVM_DEBUG(llvm::dbgs()
399 << "Query value bounds for: " << value
400 << " (owner: " << getOwnerOfValue(value)->getName() << ")\n");
401 if (valueBoundsOp) {
402 if (dim == kIndexValue) {
403 valueBoundsOp.populateBoundsForIndexValue(value, *this);
404 } else {
405 valueBoundsOp.populateBoundsForShapedValueDim(value, dim, *this);
406 }
407 continue;
408 }
409 LLVM_DEBUG(llvm::dbgs() << "--> ValueBoundsOpInterface not implemented\n");
410
411 // If the op does not implement `ValueBoundsOpInterface`, check if it
412 // implements the `DestinationStyleOpInterface`. OpResults of such ops are
413 // tied to OpOperands. Tied values have the same shape.
414 auto dstOp = value.getDefiningOp<DestinationStyleOpInterface>();
415 if (!dstOp || dim == kIndexValue)
416 continue;
417 Value tiedOperand = dstOp.getTiedOpOperand(cast<OpResult>(value))->get();
418 bound(value)[dim] == getExpr(tiedOperand, dim);
419 }
420}
421
423 assert(pos >= 0 && pos < static_cast<int64_t>(positionToValueDim.size()) &&
424 "invalid position");
425 cstr.projectOut(pos);
426 if (positionToValueDim[pos].has_value()) {
427 bool erased = valueDimToPosition.erase(*positionToValueDim[pos]);
428 (void)erased;
429 assert(erased && "inconsistent reverse mapping");
430 }
431 positionToValueDim.erase(positionToValueDim.begin() + pos);
432 // Update reverse mapping.
433 for (int64_t i = pos, e = positionToValueDim.size(); i < e; ++i)
434 if (positionToValueDim[i].has_value())
436}
437
439 function_ref<bool(ValueDim)> condition) {
440 int64_t nextPos = 0;
441 while (nextPos < static_cast<int64_t>(positionToValueDim.size())) {
442 if (positionToValueDim[nextPos].has_value() &&
443 condition(*positionToValueDim[nextPos])) {
444 projectOut(nextPos);
445 // The column was projected out so another column is now at that position.
446 // Do not increase the counter.
447 } else {
448 ++nextPos;
449 }
450 }
451}
452
454 std::optional<int64_t> except) {
455 int64_t nextPos = 0;
456 while (nextPos < static_cast<int64_t>(positionToValueDim.size())) {
457 if (positionToValueDim[nextPos].has_value() || except == nextPos) {
458 ++nextPos;
459 } else {
460 projectOut(nextPos);
461 // The column was projected out so another column is now at that position.
462 // Do not increase the counter.
463 }
464 }
465}
466
468 AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type,
469 const Variable &var, StopConditionFn stopCondition, bool closedUB) {
470 MLIRContext *ctx = var.getContext();
471 int64_t ubAdjustment = closedUB ? 0 : 1;
472 Builder b(ctx);
473 mapOperands.clear();
474
475 // Process the backward slice of `value` (i.e., reverse use-def chain) until
476 // `stopCondition` is met.
478 int64_t pos = cstr.insert(var, /*isSymbol=*/false);
479 assert(pos == 0 && "expected first column");
480 cstr.processWorklist();
481
482 // Project out all variables (apart from `valueDim`) that do not match the
483 // stop condition.
484 cstr.projectOut([&](ValueDim p) {
485 auto maybeDim =
486 p.second == kIndexValue ? std::nullopt : std::make_optional(p.second);
487 return !stopCondition(p.first, maybeDim, cstr);
488 });
489 cstr.projectOutAnonymous(/*except=*/pos);
490
491 // Compute lower and upper bounds for `valueDim`.
492 SmallVector<AffineMap> lb(1), ub(1);
493 cstr.cstr.getSliceBounds(pos, 1, ctx, &lb, &ub,
494 /*closedUB=*/true);
495
496 // Note: There are TODOs in the implementation of `getSliceBounds`. In such a
497 // case, no lower/upper bound can be computed at the moment.
498 // EQ, UB bounds: upper bound is needed.
499 if ((type != BoundType::LB) &&
500 (ub.empty() || !ub[0] || ub[0].getNumResults() == 0))
501 return failure();
502 // EQ, LB bounds: lower bound is needed.
503 if ((type != BoundType::UB) &&
504 (lb.empty() || !lb[0] || lb[0].getNumResults() == 0))
505 return failure();
506
507 // TODO: Generate an affine map with multiple results.
508 if (type != BoundType::LB)
509 assert(ub.size() == 1 && ub[0].getNumResults() == 1 &&
510 "multiple bounds not supported");
511 if (type != BoundType::UB)
512 assert(lb.size() == 1 && lb[0].getNumResults() == 1 &&
513 "multiple bounds not supported");
514
515 // EQ bound: lower and upper bound must match.
516 if (type == BoundType::EQ && ub[0] != lb[0])
517 return failure();
518
520 if (type == BoundType::EQ || type == BoundType::LB) {
521 bound = lb[0];
522 } else {
523 // Computed UB is a closed bound.
524 bound = AffineMap::get(ub[0].getNumDims(), ub[0].getNumSymbols(),
525 ub[0].getResult(0) + ubAdjustment);
526 }
527
528 // Gather all SSA values that are used in the computed bound.
529 assert(cstr.cstr.getNumDimAndSymbolVars() == cstr.positionToValueDim.size() &&
530 "inconsistent mapping state");
531 SmallVector<AffineExpr> replacementDims, replacementSymbols;
532 int64_t numDims = 0, numSymbols = 0;
533 for (int64_t i = 0; i < cstr.cstr.getNumDimAndSymbolVars(); ++i) {
534 // Skip `value`.
535 if (i == pos)
536 continue;
537 // Check if the position `i` is used in the generated bound. If so, it must
538 // be included in the generated affine.apply op.
539 bool used = false;
540 bool isDim = i < cstr.cstr.getNumDimVars();
541 if (isDim) {
542 if (bound.isFunctionOfDim(i))
543 used = true;
544 } else {
545 if (bound.isFunctionOfSymbol(i - cstr.cstr.getNumDimVars()))
546 used = true;
547 }
548
549 if (!used) {
550 // Not used: Remove dim/symbol from the result.
551 if (isDim) {
552 replacementDims.push_back(b.getAffineConstantExpr(0));
553 } else {
554 replacementSymbols.push_back(b.getAffineConstantExpr(0));
555 }
556 continue;
557 }
558
559 if (isDim) {
560 replacementDims.push_back(b.getAffineDimExpr(numDims++));
561 } else {
562 replacementSymbols.push_back(b.getAffineSymbolExpr(numSymbols++));
563 }
564
565 assert(cstr.positionToValueDim[i].has_value() &&
566 "cannot build affine map in terms of anonymous column");
567 ValueBoundsConstraintSet::ValueDim valueDim = *cstr.positionToValueDim[i];
568 Value value = valueDim.first;
569 int64_t dim = valueDim.second;
571 // An index-type value is used: can be used directly in the affine.apply
572 // op.
573 assert(value.getType().isIndex() && "expected index type");
574 mapOperands.push_back(std::make_pair(value, std::nullopt));
575 continue;
576 }
577
578 assert(cast<ShapedType>(value.getType()).isDynamicDim(dim) &&
579 "expected dynamic dim");
580 mapOperands.push_back(std::make_pair(value, dim));
581 }
582
583 resultMap = bound.replaceDimsAndSymbols(replacementDims, replacementSymbols,
584 numDims, numSymbols);
585 return success();
586}
587
589 AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type,
590 const Variable &var, ValueDimList dependencies, bool closedUB) {
591 return computeBound(
592 resultMap, mapOperands, type, var,
593 [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
594 return llvm::is_contained(dependencies, std::make_pair(v, d));
595 },
596 closedUB);
597}
598
600 AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type,
601 const Variable &var, ValueRange independencies, bool closedUB) {
602 // Return "true" if the given value is independent of all values in
603 // `independencies`. I.e., neither the value itself nor any value in the
604 // backward slice (reverse use-def chain) is contained in `independencies`.
605 auto isIndependent = [&](Value v) {
607 DenseSet<Value> visited;
608 worklist.push_back(v);
609 while (!worklist.empty()) {
610 Value next = worklist.pop_back_val();
611 if (!visited.insert(next).second)
612 continue;
613 if (llvm::is_contained(independencies, next))
614 return false;
615 // TODO: DominanceInfo could be used to stop the traversal early.
616 Operation *op = next.getDefiningOp();
617 if (!op)
618 continue;
619 worklist.append(op->getOperands().begin(), op->getOperands().end());
620 }
621 return true;
622 };
623
624 // Reify bounds in terms of any independent values.
625 return computeBound(
626 resultMap, mapOperands, type, var,
627 [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
628 return isIndependent(v);
629 },
630 closedUB);
631}
632
634 presburger::BoundType type, const Variable &var,
635 const StopConditionFn &stopCondition, bool closedUB) {
636 // Default stop condition if none was specified: Keep adding constraints until
637 // a bound could be computed.
638 int64_t pos = 0;
639 auto defaultStopCondition = [&](Value v, std::optional<int64_t> dim,
641 return cstr.cstr.getConstantBound64(type, pos).has_value();
642 };
643
645 var.getContext(), stopCondition ? stopCondition : defaultStopCondition);
646 pos = cstr.populateConstraints(var.map, var.mapOperands);
647 assert(pos == 0 && "expected `map` is the first column");
648
649 // Compute constant bound for `valueDim`.
650 int64_t ubAdjustment = closedUB ? 0 : 1;
651 if (auto bound = cstr.cstr.getConstantBound64(type, pos))
652 return type == BoundType::UB ? *bound + ubAdjustment : *bound;
653 return failure();
654}
655
657 std::optional<int64_t> dim) {
658#ifndef NDEBUG
659 assertValidValueDim(value, dim);
660#endif // NDEBUG
661
662 // `getExpr` pushes the value/dim onto the worklist (unless it was already
663 // analyzed).
664 (void)getExpr(value, dim);
665 // Process all values/dims on the worklist. This may traverse and analyze
666 // additional IR, depending the current stop function.
668}
669
671 ValueDimList operands) {
672 int64_t pos = insert(map, std::move(operands), /*isSymbol=*/false);
673 // Process the backward slice of `operands` (i.e., reverse use-def chain)
674 // until `stopCondition` is met.
676 return pos;
677}
678
679FailureOr<int64_t>
681 std::optional<int64_t> dim1,
682 std::optional<int64_t> dim2) {
683#ifndef NDEBUG
684 assertValidValueDim(value1, dim1);
685 assertValidValueDim(value2, dim2);
686#endif // NDEBUG
687
688 Builder b(value1.getContext());
689 AffineMap map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0,
690 b.getAffineDimExpr(0) - b.getAffineDimExpr(1));
692 Variable(map, {{value1, dim1}, {value2, dim2}}));
693}
694
697 int64_t rhsPos) {
698 // This function returns "true" if "lhs CMP rhs" is proven to hold.
699 //
700 // Example for ComparisonOperator::LE and index-typed values: We would like to
701 // prove that lhs <= rhs. Proof by contradiction: add the inverse
702 // relation (lhs > rhs) to the constraint set and check if the resulting
703 // constraint set is "empty" (i.e. has no solution). In that case,
704 // lhs > rhs must be incorrect and we can deduce that lhs <= rhs holds.
705
706 // We cannot prove anything if the constraint set is already empty.
707 if (cstr.isEmpty()) {
708 LLVM_DEBUG(
709 llvm::dbgs()
710 << "cannot compare value/dims: constraint system is already empty");
711 return false;
712 }
713
714 // EQ can be expressed as LE and GE.
715 if (cmp == EQ)
716 return comparePos(lhsPos, ComparisonOperator::LE, rhsPos) &&
717 comparePos(lhsPos, ComparisonOperator::GE, rhsPos);
718
719 // Construct inequality.
720 SmallVector<int64_t> eq(cstr.getNumCols(), 0);
721 if (cmp == LT || cmp == LE) {
722 ++eq[lhsPos];
723 --eq[rhsPos];
724 } else if (cmp == GT || cmp == GE) {
725 --eq[lhsPos];
726 ++eq[rhsPos];
727 } else {
728 llvm_unreachable("unsupported comparison operator");
729 }
730 if (cmp == LE || cmp == GE)
731 eq[cstr.getNumCols() - 1] -= 1;
732
733 // Add inequality to the constraint set and check if it made the constraint
734 // set empty.
735 int64_t ineqPos = cstr.getNumInequalities();
736 cstr.addInequality(eq);
737 bool isEmpty = cstr.isEmpty();
738 cstr.removeInequality(ineqPos);
739 return isEmpty;
740}
741
743 int64_t lhsPos, ComparisonOperator cmp, int64_t rhsPos) {
744 auto strongCmp = [&](ComparisonOperator cmp,
745 ComparisonOperator negCmp) -> FailureOr<bool> {
746 if (comparePos(lhsPos, cmp, rhsPos))
747 return true;
748 if (comparePos(lhsPos, negCmp, rhsPos))
749 return false;
750 return failure();
751 };
752 switch (cmp) {
762 std::optional<bool> le =
764 if (!le)
765 return failure();
766 if (!*le)
767 return false;
768 std::optional<bool> ge =
770 if (!ge)
771 return failure();
772 if (!*ge)
773 return false;
774 return true;
775 }
776 }
777 llvm_unreachable("invalid comparison operator");
778}
779
782 const Variable &rhs) {
783 int64_t lhsPos = populateConstraints(lhs.map, lhs.mapOperands);
784 int64_t rhsPos = populateConstraints(rhs.map, rhs.mapOperands);
785 return comparePos(lhsPos, cmp, rhsPos);
786}
787
790 const Variable &rhs) {
791 int64_t lhsPos = -1, rhsPos = -1;
792 auto stopCondition = [&](Value v, std::optional<int64_t> dim,
794 // Keep processing as long as lhs/rhs were not processed.
795 if (size_t(lhsPos) >= cstr.positionToValueDim.size() ||
796 size_t(rhsPos) >= cstr.positionToValueDim.size())
797 return false;
798 // Keep processing as long as the relation cannot be proven.
799 return cstr.comparePos(lhsPos, cmp, rhsPos);
800 };
802 lhsPos = cstr.populateConstraints(lhs.map, lhs.mapOperands);
803 rhsPos = cstr.populateConstraints(rhs.map, rhs.mapOperands);
804 return cstr.comparePos(lhsPos, cmp, rhsPos);
805}
806
809 const Variable &rhs) {
810 int64_t lhsPos = -1, rhsPos = -1;
811 auto stopCondition = [&](Value v, std::optional<int64_t> dim,
813 // Keep processing as long as lhs/rhs were not processed.
814 if (size_t(lhsPos) >= cstr.positionToValueDim.size() ||
815 size_t(rhsPos) >= cstr.positionToValueDim.size())
816 return false;
817 // Keep processing as long as the strong relation cannot be proven.
818 FailureOr<bool> ordered = cstr.strongComparePos(lhsPos, cmp, rhsPos);
819 return failed(ordered);
820 };
822 lhsPos = cstr.populateConstraints(lhs.map, lhs.mapOperands);
823 rhsPos = cstr.populateConstraints(rhs.map, rhs.mapOperands);
824 return cstr.strongComparePos(lhsPos, cmp, rhsPos);
825}
826
828 const Variable &var2) {
829 return strongCompare(var1, ComparisonOperator::EQ, var2);
830}
831
833 MLIRContext *ctx, const HyperrectangularSlice &slice1,
834 const HyperrectangularSlice &slice2) {
835 assert(slice1.getMixedOffsets().size() == slice2.getMixedOffsets().size() &&
836 "expected slices of same rank");
837 assert(slice1.getMixedSizes().size() == slice2.getMixedSizes().size() &&
838 "expected slices of same rank");
839 assert(slice1.getMixedStrides().size() == slice2.getMixedStrides().size() &&
840 "expected slices of same rank");
841
842 Builder b(ctx);
843 bool foundUnknownBound = false;
844 for (int64_t i = 0, e = slice1.getMixedOffsets().size(); i < e; ++i) {
845 AffineMap map =
846 AffineMap::get(/*dimCount=*/0, /*symbolCount=*/4,
847 b.getAffineSymbolExpr(0) +
848 b.getAffineSymbolExpr(1) * b.getAffineSymbolExpr(2) -
849 b.getAffineSymbolExpr(3));
850 {
851 // Case 1: Slices are guaranteed to be non-overlapping if
852 // offset1 + size1 * stride1 <= offset2 (for at least one dimension).
853 SmallVector<OpFoldResult> ofrOperands;
854 ofrOperands.push_back(slice1.getMixedOffsets()[i]);
855 ofrOperands.push_back(slice1.getMixedSizes()[i]);
856 ofrOperands.push_back(slice1.getMixedStrides()[i]);
857 ofrOperands.push_back(slice2.getMixedOffsets()[i]);
858 SmallVector<Value> valueOperands;
859 AffineMap foldedMap =
860 foldAttributesIntoMap(b, map, ofrOperands, valueOperands);
861 FailureOr<int64_t> constBound = computeConstantBound(
862 presburger::BoundType::EQ, Variable(foldedMap, valueOperands));
863 foundUnknownBound |= failed(constBound);
864 if (succeeded(constBound) && *constBound <= 0)
865 return false;
866 }
867 {
868 // Case 2: Slices are guaranteed to be non-overlapping if
869 // offset2 + size2 * stride2 <= offset1 (for at least one dimension).
870 SmallVector<OpFoldResult> ofrOperands;
871 ofrOperands.push_back(slice2.getMixedOffsets()[i]);
872 ofrOperands.push_back(slice2.getMixedSizes()[i]);
873 ofrOperands.push_back(slice2.getMixedStrides()[i]);
874 ofrOperands.push_back(slice1.getMixedOffsets()[i]);
875 SmallVector<Value> valueOperands;
876 AffineMap foldedMap =
877 foldAttributesIntoMap(b, map, ofrOperands, valueOperands);
878 FailureOr<int64_t> constBound = computeConstantBound(
879 presburger::BoundType::EQ, Variable(foldedMap, valueOperands));
880 foundUnknownBound |= failed(constBound);
881 if (succeeded(constBound) && *constBound <= 0)
882 return false;
883 }
884 }
885
886 // If at least one bound could not be computed, we cannot be certain that the
887 // slices are really overlapping.
888 if (foundUnknownBound)
889 return failure();
890
891 // All bounds could be computed and none of the above cases applied.
892 // Therefore, the slices are guaranteed to overlap.
893 return true;
894}
895
897 MLIRContext *ctx, const HyperrectangularSlice &slice1,
898 const HyperrectangularSlice &slice2) {
899 assert(slice1.getMixedOffsets().size() == slice2.getMixedOffsets().size() &&
900 "expected slices of same rank");
901 assert(slice1.getMixedSizes().size() == slice2.getMixedSizes().size() &&
902 "expected slices of same rank");
903 assert(slice1.getMixedStrides().size() == slice2.getMixedStrides().size() &&
904 "expected slices of same rank");
905
906 // The two slices are equivalent if all of their offsets, sizes and strides
907 // are equal. If equality cannot be determined for at least one of those
908 // values, equivalence cannot be determined and this function returns
909 // "failure".
910 for (auto [offset1, offset2] :
911 llvm::zip_equal(slice1.getMixedOffsets(), slice2.getMixedOffsets())) {
912 FailureOr<bool> equal = areEqual(offset1, offset2);
913 if (failed(equal))
914 return failure();
915 if (!equal.value())
916 return false;
917 }
918 for (auto [size1, size2] :
919 llvm::zip_equal(slice1.getMixedSizes(), slice2.getMixedSizes())) {
920 FailureOr<bool> equal = areEqual(size1, size2);
921 if (failed(equal))
922 return failure();
923 if (!equal.value())
924 return false;
925 }
926 for (auto [stride1, stride2] :
927 llvm::zip_equal(slice1.getMixedStrides(), slice2.getMixedStrides())) {
928 FailureOr<bool> equal = areEqual(stride1, stride2);
929 if (failed(equal))
930 return failure();
931 if (!equal.value())
932 return false;
933 }
934 return true;
935}
936
938 llvm::errs() << "==========\nColumns:\n";
939 llvm::errs() << "(column\tdim\tvalue)\n";
940 for (auto [index, valueDim] : llvm::enumerate(positionToValueDim)) {
941 llvm::errs() << " " << index << "\t";
942 if (valueDim) {
943 if (valueDim->second == kIndexValue) {
944 llvm::errs() << "n/a\t";
945 } else {
946 llvm::errs() << valueDim->second << "\t";
947 }
948 llvm::errs() << getOwnerOfValue(valueDim->first)->getName() << " ";
949 if (OpResult result = dyn_cast<OpResult>(valueDim->first)) {
950 llvm::errs() << "(result " << result.getResultNumber() << ")";
951 } else {
952 llvm::errs() << "(bbarg "
953 << cast<BlockArgument>(valueDim->first).getArgNumber()
954 << ")";
955 }
956 llvm::errs() << "\n";
957 } else {
958 llvm::errs() << "n/a\tn/a\n";
959 }
960 }
961 llvm::errs() << "\nConstraint set:\n";
962 cstr.dump();
963 llvm::errs() << "==========\n";
964}
965
968 assert(!this->dim.has_value() && "dim was already set");
969 this->dim = dim;
970#ifndef NDEBUG
971 assertValidValueDim(value, this->dim);
972#endif // NDEBUG
973 return *this;
974}
975
977#ifndef NDEBUG
978 assertValidValueDim(value, this->dim);
979#endif // NDEBUG
980 cstr.addBound(BoundType::UB, cstr.getPos(value, this->dim), expr);
981}
982
986
990
992#ifndef NDEBUG
993 assertValidValueDim(value, this->dim);
994#endif // NDEBUG
995 cstr.addBound(BoundType::LB, cstr.getPos(value, this->dim), expr);
996}
997
999#ifndef NDEBUG
1000 assertValidValueDim(value, this->dim);
1001#endif // NDEBUG
1002 cstr.addBound(BoundType::EQ, cstr.getPos(value, this->dim), expr);
1003}
1004
1008
1012
1016
1020
1024
1028
1032
1036
1040
return success()
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static Operation * getOwnerOfValue(Value value)
static void assertValidValueDim(Value value, std::optional< int64_t > dim)
Base type for affine expression.
Definition AffineExpr.h:68
AffineExpr replaceDimsAndSymbols(ArrayRef< AffineExpr > dimReplacements, ArrayRef< AffineExpr > symReplacements) const
This method substitutes any uses of dimensions and symbols (e.g.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumDims() const
unsigned getNumResults() const
AffineMap replaceDimsAndSymbols(ArrayRef< AffineExpr > dimReplacements, ArrayRef< AffineExpr > symReplacements, unsigned numResultDims, unsigned numResultSyms) const
This method substitutes any uses of dimensions and symbols (e.g.
AffineExpr getResult(unsigned idx) const
AffineMap replace(AffineExpr expr, AffineExpr replacement, unsigned numResultDims, unsigned numResultSyms) const
Sparse replace method.
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
A hyperrectangular slice, represented as a list of offsets, sizes and strides.
HyperrectangularSlice(ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, ArrayRef< OpFoldResult > strides)
ArrayRef< OpFoldResult > getMixedStrides() const
ArrayRef< OpFoldResult > getMixedSizes() const
ArrayRef< OpFoldResult > getMixedOffsets() const
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class represents a single result from folding an operation.
MLIRContext * getContext() const
This is a value defined by a result of an operation.
Definition Value.h:457
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
bool isIndex() const
Definition Types.cpp:54
Helper class that builds a bound for a shaped value dimension or index-typed value.
BoundBuilder & operator[](int64_t dim)
Specify a dimension, assuming that the underlying value is a shaped value.
A variable that can be added to the constraint set as a "column".
Variable(OpFoldResult ofr)
Construct a variable for an index-typed attribute or SSA value.
static bool compare(const Variable &lhs, ComparisonOperator cmp, const Variable &rhs)
Return "true" if "lhs cmp rhs" was proven to hold.
static FailureOr< bool > areEqual(const Variable &var1, const Variable &var2)
Compute whether the given variables are equal.
DenseMap< ValueDim, int64_t > valueDimToPosition
Reverse mapping of values/shape dimensions to columns.
void processWorklist()
Iteratively process all elements on the worklist until an index-typed value or shaped value meets sto...
bool addConservativeSemiAffineBounds
Should conservative bounds be added for semi-affine expressions.
AffineExpr getExpr(Value value, std::optional< int64_t > dim=std::nullopt)
Return an expression that represents the given index-typed value or shaped value dimension.
SmallVector< std::optional< ValueDim > > positionToValueDim
Mapping of columns to values/shape dimensions.
static LogicalResult computeIndependentBound(AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type, const Variable &var, ValueRange independencies, bool closedUB=false)
Compute a bound in that is independent of all values in independencies.
static FailureOr< bool > areEquivalentSlices(MLIRContext *ctx, const HyperrectangularSlice &slice1, const HyperrectangularSlice &slice2)
Return "true" if the given slices are guaranteed to be equivalent.
void projectOut(int64_t pos)
Project out the given column in the constraint set.
std::function< bool( Value, std::optional< int64_t >, ValueBoundsConstraintSet &cstr)> StopConditionFn
The stop condition when traversing the backward slice of a shaped value/ index-type value.
ValueBoundsConstraintSet(MLIRContext *ctx, const StopConditionFn &stopCondition, bool addConservativeSemiAffineBounds=false)
static FailureOr< int64_t > computeConstantDelta(Value value1, Value value2, std::optional< int64_t > dim1=std::nullopt, std::optional< int64_t > dim2=std::nullopt)
Compute a constant delta between the given two values.
static llvm::FailureOr< bool > strongCompare(const Variable &lhs, ComparisonOperator cmp, const Variable &rhs)
This function is similar to ValueBoundsConstraintSet::compare, except that it returns false if !...
void addBound(presburger::BoundType type, int64_t pos, AffineExpr expr)
Bound the given column in the underlying constraint set by the given expression.
StopConditionFn stopCondition
The current stop condition function.
ComparisonOperator
Comparison operator for ValueBoundsConstraintSet::compare.
BoundBuilder bound(Value value)
Add a bound for the given index-typed value or shaped value.
static LogicalResult computeBound(AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type, const Variable &var, StopConditionFn stopCondition, bool closedUB=false)
Compute a bound for the given variable.
int64_t getPos(Value value, std::optional< int64_t > dim=std::nullopt) const
Return the column position of the given value/dimension.
int64_t insert(Value value, std::optional< int64_t > dim, bool isSymbol=true, bool addToWorklist=true)
Insert a value/dimension into the constraint set.
bool comparePos(int64_t lhsPos, ComparisonOperator cmp, int64_t rhsPos)
Return "true" if, based on the current state of the constraint system, "lhs cmp rhs" was proven to ho...
void dump() const
Debugging only: Dump the constraint set and the column-to-value/dim mapping to llvm::errs.
std::queue< int64_t > worklist
Worklist of values/shape dimensions that have not been processed yet.
FlatLinearConstraints cstr
Constraint system of equalities and inequalities.
bool isMapped(Value value, std::optional< int64_t > dim=std::nullopt) const
Return "true" if the given value/dim is mapped (i.e., has a corresponding column in the constraint sy...
llvm::FailureOr< bool > strongComparePos(int64_t lhsPos, ComparisonOperator cmp, int64_t rhsPos)
Return "true" if, based on the current state of the constraint system, "lhs cmp rhs" was proven to ho...
AffineExpr getPosExpr(int64_t pos)
Return an affine expression that represents column pos in the constraint set.
void projectOutAnonymous(std::optional< int64_t > except=std::nullopt)
static FailureOr< bool > areOverlappingSlices(MLIRContext *ctx, const HyperrectangularSlice &slice1, const HyperrectangularSlice &slice2)
Return "true" if the given slices are guaranteed to be overlapping.
std::pair< Value, int64_t > ValueDim
An index-typed value or the dimension of a shaped-type value.
void populateConstraints(Value value, std::optional< int64_t > dim)
Traverse the IR starting from the given value/dim and populate constraints as long as the stop condit...
Builder builder
Builder for constructing affine expressions.
bool populateAndCompare(const Variable &lhs, ComparisonOperator cmp, const Variable &rhs)
Populate constraints for lhs/rhs (until the stop condition is met).
static FailureOr< int64_t > computeConstantBound(presburger::BoundType type, const Variable &var, const StopConditionFn &stopCondition=nullptr, bool closedUB=false)
Compute a constant bound for the given variable.
static constexpr int64_t kIndexValue
Dimension identifier to indicate a value is index-typed.
static LogicalResult computeDependentBound(AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type, const Variable &var, ValueDimList dependencies, bool closedUB=false)
Compute a bound in terms of the values/dimensions in dependencies.
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition Value.h:108
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition CallGraph.h:229
BoundType
The type of bound: equal, lower bound or upper bound.
VarKind
Kind of variable.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:128
bool operator==(StringAttr lhs, std::nullptr_t)
Define comparisons for StringAttr against nullptr and itself to avoid the StringRef overloads from be...
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
SmallVector< std::pair< Value, std::optional< int64_t > > > ValueDimList
AffineMap foldAttributesIntoMap(Builder &b, AffineMap map, ArrayRef< OpFoldResult > operands, SmallVector< Value > &remainingValues)
Fold all attributes among the given operands into the affine map.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152