MLIR  19.0.0git
BuiltinAttributeInterfaces.cpp
Go to the documentation of this file.
1 //===- BuiltinAttributeInterfaces.cpp -------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "mlir/IR/BuiltinTypes.h"
11 #include "mlir/IR/Diagnostics.h"
12 #include "llvm/ADT/Sequence.h"
13 
14 using namespace mlir;
15 using namespace mlir::detail;
16 
17 //===----------------------------------------------------------------------===//
18 /// Tablegen Interface Definitions
19 //===----------------------------------------------------------------------===//
20 
21 #include "mlir/IR/BuiltinAttributeInterfaces.cpp.inc"
22 
23 //===----------------------------------------------------------------------===//
24 // ElementsAttr
25 //===----------------------------------------------------------------------===//
26 
27 Type ElementsAttr::getElementType(ElementsAttr elementsAttr) {
28  return elementsAttr.getShapedType().getElementType();
29 }
30 
31 int64_t ElementsAttr::getNumElements(ElementsAttr elementsAttr) {
32  return elementsAttr.getShapedType().getNumElements();
33 }
34 
35 bool ElementsAttr::isValidIndex(ShapedType type, ArrayRef<uint64_t> index) {
36  // Verify that the rank of the indices matches the held type.
37  int64_t rank = type.getRank();
38  if (rank == 0 && index.size() == 1 && index[0] == 0)
39  return true;
40  if (rank != static_cast<int64_t>(index.size()))
41  return false;
42 
43  // Verify that all of the indices are within the shape dimensions.
44  ArrayRef<int64_t> shape = type.getShape();
45  return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) {
46  int64_t dim = static_cast<int64_t>(index[i]);
47  return 0 <= dim && dim < shape[i];
48  });
49 }
50 bool ElementsAttr::isValidIndex(ElementsAttr elementsAttr,
51  ArrayRef<uint64_t> index) {
52  return isValidIndex(elementsAttr.getShapedType(), index);
53 }
54 
55 uint64_t ElementsAttr::getFlattenedIndex(Type type, ArrayRef<uint64_t> index) {
56  ShapedType shapeType = llvm::cast<ShapedType>(type);
57  assert(isValidIndex(shapeType, index) &&
58  "expected valid multi-dimensional index");
59 
60  // Reduce the provided multidimensional index into a flattended 1D row-major
61  // index.
62  auto rank = shapeType.getRank();
63  ArrayRef<int64_t> shape = shapeType.getShape();
64  uint64_t valueIndex = 0;
65  uint64_t dimMultiplier = 1;
66  for (int i = rank - 1; i >= 0; --i) {
67  valueIndex += index[i] * dimMultiplier;
68  dimMultiplier *= shape[i];
69  }
70  return valueIndex;
71 }
72 
73 //===----------------------------------------------------------------------===//
74 // MemRefLayoutAttrInterface
75 //===----------------------------------------------------------------------===//
76 
80  if (m.getNumDims() != shape.size())
81  return emitError() << "memref layout mismatch between rank and affine map: "
82  << shape.size() << " != " << m.getNumDims();
83 
84  return success();
85 }
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:216
static int64_t getNumElements(ShapedType type)
Definition: TensorOps.cpp:1541
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:47
unsigned getNumDims() const
Definition: AffineMap.cpp:378
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:308
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
Detect if any of the given parameter types has a sub-element handler.
LogicalResult verifyAffineMapAsLayout(AffineMap m, ArrayRef< int64_t > shape, function_ref< InFlightDiagnostic()> emitError)
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26