MLIR  22.0.0git
RuntimeOpVerification.cpp
Go to the documentation of this file.
1 //===- RuntimeOpVerification.cpp - Op Verification ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
17 
18 using namespace mlir;
19 
20 namespace mlir {
21 namespace tensor {
22 namespace {
23 /// Generate a runtime check for lb <= value < ub.
24 Value generateInBoundsCheck(OpBuilder &builder, Location loc, Value value,
25  Value lb, Value ub) {
26  Value inBounds1 = builder.createOrFold<arith::CmpIOp>(
27  loc, arith::CmpIPredicate::sge, value, lb);
28  Value inBounds2 = builder.createOrFold<arith::CmpIOp>(
29  loc, arith::CmpIPredicate::slt, value, ub);
30  Value inBounds =
31  builder.createOrFold<arith::AndIOp>(loc, inBounds1, inBounds2);
32  return inBounds;
33 }
34 
35 struct CastOpInterface
36  : public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
37  CastOp> {
38  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
39  Location loc) const {
40  auto castOp = cast<CastOp>(op);
41  auto srcType = cast<TensorType>(castOp.getSource().getType());
42 
43  // Nothing to check if the result is an unranked tensor.
44  auto resultType = dyn_cast<RankedTensorType>(castOp.getType());
45  if (!resultType)
46  return;
47 
48  if (isa<UnrankedTensorType>(srcType)) {
49  // Check rank.
50  Value srcRank = RankOp::create(builder, loc, castOp.getSource());
51  Value resultRank =
52  arith::ConstantIndexOp::create(builder, loc, resultType.getRank());
53  Value isSameRank = arith::CmpIOp::create(
54  builder, loc, arith::CmpIPredicate::eq, srcRank, resultRank);
55  cf::AssertOp::create(builder, loc, isSameRank,
56  RuntimeVerifiableOpInterface::generateErrorMessage(
57  op, "rank mismatch"));
58  }
59 
60  // Check dimension sizes.
61  for (const auto &it : llvm::enumerate(resultType.getShape())) {
62  // Static dim size -> static/dynamic dim size does not need verification.
63  if (auto rankedSrcType = dyn_cast<RankedTensorType>(srcType))
64  if (!rankedSrcType.isDynamicDim(it.index()))
65  continue;
66 
67  // Static/dynamic dim size -> dynamic dim size does not need verification.
68  if (resultType.isDynamicDim(it.index()))
69  continue;
70 
71  Value srcDimSz =
72  DimOp::create(builder, loc, castOp.getSource(), it.index());
73  Value resultDimSz =
74  arith::ConstantIndexOp::create(builder, loc, it.value());
75  Value isSameSz = arith::CmpIOp::create(
76  builder, loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
77  cf::AssertOp::create(
78  builder, loc, isSameSz,
79  RuntimeVerifiableOpInterface::generateErrorMessage(
80  op, "size mismatch of dim " + std::to_string(it.index())));
81  }
82  }
83 };
84 
85 struct DimOpInterface
86  : public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface,
87  DimOp> {
88  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
89  Location loc) const {
90  auto dimOp = cast<DimOp>(op);
91  Value rank = RankOp::create(builder, loc, dimOp.getSource());
92  Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
93  cf::AssertOp::create(
94  builder, loc,
95  generateInBoundsCheck(builder, loc, dimOp.getIndex(), zero, rank),
96  RuntimeVerifiableOpInterface::generateErrorMessage(
97  op, "index is out of bounds"));
98  }
99 };
100 
101 /// Verifies that the indices on extract/insert ops are in-bounds of the
102 /// tensor's index space: 0 <= index#i < dim#i
103 template <typename OpTy>
104 struct ExtractInsertOpInterface
105  : public RuntimeVerifiableOpInterface::ExternalModel<
106  ExtractInsertOpInterface<OpTy>, OpTy> {
107  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
108  Location loc) const {
109  auto extractInsertOp = cast<OpTy>(op);
110 
111  Value tensor;
112  if constexpr (std::is_same_v<OpTy, ExtractOp>) {
113  tensor = extractInsertOp.getTensor();
114  } else if constexpr (std::is_same_v<OpTy, InsertOp>) {
115  tensor = extractInsertOp.getDest();
116  } else {
117  llvm_unreachable("invalid op");
118  }
119  auto tensorType = cast<RankedTensorType>(tensor.getType());
120  auto rank = tensorType.getRank();
121  if (rank == 0) {
122  // Nothing to check for 0-d tensors.
123  return;
124  }
125 
126  auto indices = extractInsertOp.getIndices();
127  auto zero = arith::ConstantIndexOp::create(builder, loc, 0);
128  Value assertCond;
129  for (auto i : llvm::seq<int64_t>(0, rank)) {
130  Value dimOp = builder.createOrFold<tensor::DimOp>(loc, tensor, i);
131  Value inBounds =
132  generateInBoundsCheck(builder, loc, indices[i], zero, dimOp);
133  assertCond =
134  i > 0 ? builder.createOrFold<arith::AndIOp>(loc, assertCond, inBounds)
135  : inBounds;
136  }
137  cf::AssertOp::create(builder, loc, assertCond,
138  RuntimeVerifiableOpInterface::generateErrorMessage(
139  op, "out-of-bounds access"));
140  }
141 };
142 
143 struct ExtractSliceOpInterface
144  : public RuntimeVerifiableOpInterface::ExternalModel<
145  ExtractSliceOpInterface, ExtractSliceOp> {
146  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
147  Location loc) const {
148  auto extractSliceOp = cast<ExtractSliceOp>(op);
149  RankedTensorType sourceType = extractSliceOp.getSource().getType();
150 
151  // For each dimension, assert that:
152  // 0 <= offset < dim_size
153  // 0 <= offset + (size - 1) * stride < dim_size
154  Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
155  Value one = arith::ConstantIndexOp::create(builder, loc, 1);
156  for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
158  builder, loc, extractSliceOp.getMixedOffsets()[i]);
160  builder, loc, extractSliceOp.getMixedSizes()[i]);
162  builder, loc, extractSliceOp.getMixedStrides()[i]);
163 
164  // Verify that offset is in-bounds.
165  Value dimSize = builder.createOrFold<tensor::DimOp>(
166  loc, extractSliceOp.getSource(), i);
167  Value offsetInBounds =
168  generateInBoundsCheck(builder, loc, offset, zero, dimSize);
169  cf::AssertOp::create(
170  builder, loc, offsetInBounds,
171  RuntimeVerifiableOpInterface::generateErrorMessage(
172  op, "offset " + std::to_string(i) + " is out-of-bounds"));
173 
174  // Verify that slice does not run out-of-bounds.
175  Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one);
176  Value sizeMinusOneTimesStride =
177  arith::MulIOp::create(builder, loc, sizeMinusOne, stride);
178  Value lastPos =
179  arith::AddIOp::create(builder, loc, offset, sizeMinusOneTimesStride);
180  Value lastPosInBounds =
181  generateInBoundsCheck(builder, loc, lastPos, zero, dimSize);
182  cf::AssertOp::create(
183  builder, loc, lastPosInBounds,
184  RuntimeVerifiableOpInterface::generateErrorMessage(
185  op, "extract_slice runs out-of-bounds along dimension " +
186  std::to_string(i)));
187  }
188  }
189 };
190 } // namespace
191 } // namespace tensor
192 } // namespace mlir
193 
195  DialectRegistry &registry) {
196  registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
197  CastOp::attachInterface<CastOpInterface>(*ctx);
198  DimOp::attachInterface<DimOpInterface>(*ctx);
199  ExtractOp::attachInterface<ExtractInsertOpInterface<ExtractOp>>(*ctx);
200  ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
201  InsertOp::attachInterface<ExtractInsertOpInterface<InsertOp>>(*ctx);
202 
203  // Load additional dialects of which ops may get created.
204  ctx->loadDialect<arith::ArithDialect, cf::ControlFlowDialect>();
205  });
206 }
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
This class helps build Operations.
Definition: Builders.h:205
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:517
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition: ArithOps.cpp:359
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
void registerRuntimeVerifiableOpInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:111