29 #include "llvm/ADT/APFloat.h"
30 #include "llvm/ADT/DenseMap.h"
31 #include "llvm/ADT/TypeSwitch.h"
36 #include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"
42 #include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
45 #include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"
return (isa<tosa::IfOp>(dest->getParentOp()) ||
        isa<tosa::WhileOp>(dest->getParentOp()));
TosaDialectBytecodeInterface(Dialect *dialect)
    : BytecodeDialectInterface(dialect) {}
// ...
return ::writeAttribute(attr, writer);
// ...
return ::writeType(type, writer);
// ...
std::unique_ptr<DialectVersion>
readVersion(DialectBytecodeReader &reader) const final {
  reader.emitError("Dialect does not support versioning");
  return nullptr;
}
void TosaDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
      >();
  addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
  declarePromisedInterfaces<
      mesh::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
      ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, DivOp,
      LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,
      LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,
      BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp,
      NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp,
      GreaterEqualOp, MatMulOp>();
}
if (llvm::isa<ElementsAttr>(value))
  return builder.create<tosa::ConstOp>(loc, type,
                                       llvm::cast<ElementsAttr>(value));
return parser.emitError(parser.getCurrentLocation())
       << "expected attribute";
// ...
if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
  typeAttr = TypeAttr::get(typedAttr.getType());
}
bool needsSpace = false;
auto typedAttr = dyn_cast_or_null<TypedAttr>(attr);
if (!typedAttr || typedAttr.getType() != type.getValue()) {
static bool hasZeroDimension(ShapedType shapedType) {
  if (!shapedType.hasRank())
    return false;
  auto rank = shapedType.getRank();
  for (int i = 0; i < rank; i++) {
    if (shapedType.isDynamicDim(i))
      continue;
    if (shapedType.getDimSize(i) == 0)
      return true;
  }
  return false;
}
template <typename T>
static LogicalResult verifyConvOp(T op) {
  auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
  auto weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());

  // Must be ranked tensor types.
  if (!inputType) {
    op.emitOpError("expect a ranked tensor for input, got ") << op.getInput();
    return failure();
  }
  if (!weightType) {
    op.emitOpError("expect a ranked tensor for weight, got ") << op.getWeight();
    return failure();
  }

  if (hasZeroDimension(inputType))
    return op.emitOpError() << "tensor has a dimension with size zero. Each "
                               "dimension of a tensor must have size >= 1";

  auto inputEType = inputType.getElementType();
  auto weightEType = weightType.getElementType();

  bool inputIsQuant = !llvm::isa<FloatType>(inputEType);
  bool weightIsQuant = !llvm::isa<FloatType>(weightEType);

  // Either both must be quantized or both unquantized.
  if (inputIsQuant != weightIsQuant) {
    op.emitOpError(
        "expect both input and weight to be float or not together, got ")
        << inputEType << " and " << weightEType;
    return failure();
  }

  // A quantized type must carry quantization info; a float type must not.
  if ((inputIsQuant && !op.getQuantizationInfo()) ||
      (!inputIsQuant && op.getQuantizationInfo())) {
    op.emitOpError("quantizationattr is required for quantized type, and not "
                   "allowed for float type");
    return failure();
  }
LogicalResult tosa::ArgMaxOp::verify() {
  const auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
  if (!resultETy.isIntOrIndex())
    return emitOpError("result tensor is not of integer type");

  const auto inputType = llvm::cast<ShapedType>(getInput().getType());
  const int64_t axis = getAxisAttr().getInt();
  if (inputType.hasRank() && ((axis < 0) || axis >= inputType.getRank()))
    return emitOpError("specified axis is outside the rank of the tensor");

  return success();
}
LogicalResult tosa::AvgPool2dOp::verify() {
  auto inputType = llvm::cast<ShapedType>(getInput().getType());
  if (hasZeroDimension(inputType))
    return emitOpError() << "tensor has a dimension with size zero. Each "
                            "dimension of a tensor must have size >= 1";

  auto inputETy = inputType.getElementType();
  auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();

  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy))
    inputETy = quantType.getStorageType();

  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(resultETy))
    resultETy = quantType.getStorageType();

  auto accType = getAccType();
  if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
    return emitOpError("accumulator type for integer tensor is not i32");

  if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))
    return emitOpError("accumulator type for f16 tensor is not f16/f32");

  if (inputETy.isBF16() && !accType.isF32())
    return emitOpError("accumulator type for bf16 tensor is not f32");

  if (inputETy.isF32() && !accType.isF32())
    return emitOpError("accumulator type for f32 tensor is not f32");

  if ((inputETy.isF32() && resultETy.isF32()) ||
      (inputETy.isF16() && resultETy.isF16()) ||
      (inputETy.isBF16() && resultETy.isBF16()) ||
      (inputETy.isInteger(8) && resultETy.isInteger(8)) ||
      (inputETy.isInteger(16) && resultETy.isInteger(16)))
    return success();

  return emitOpError("input/output element types are incompatible.");
}
LogicalResult tosa::ClampOp::verify() {
  mlir::Type inputETy =
      llvm::cast<ShapedType>(getInput().getType()).getElementType();
  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
    inputETy = quantType.getStorageType();
  }
  mlir::Type maxFpType = getMaxFpAttr().getType();
  mlir::Type minFpType = getMinFpAttr().getType();
  mlir::Type outputETy =
      llvm::cast<ShapedType>(getOutput().getType()).getElementType();
  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
    outputETy = quantType.getStorageType();
  }

  if (inputETy != outputETy)
    return emitOpError("input/output element types are incompatible.");

  unsigned dataTypeBitWidth = inputETy.getIntOrFloatBitWidth();
  // For float inputs, the min/max attributes must agree with each other and
  // with the input/output element type.
  if (!inputETy.isInteger(dataTypeBitWidth)) {
    if (((maxFpType != minFpType) ||
         /* ... */))
      return emitOpError("min/max attributes types are incompatible with "
                         "input/output element types.");
  }
auto inputType = llvm::dyn_cast<ShapedType>(a.getType());
assert(inputType && "Input must be a shaped tensor type!");

auto inputQType = llvm::dyn_cast<mlir::quant::UniformQuantizedType>(
    inputType.getElementType());
assert(inputQType && "Tensor must have quantized datatype!");

unsigned inputBits = inputQType.getStorageTypeIntegralWidth();

auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
assert(outputShapedType && "Output must be a shaped type");

// TOSA uses a 48-bit accumulator for 16-bit operands, 32-bit otherwise.
IntegerType accElementType;
if (inputBits == 16)
  accElementType = builder.getIntegerType(48);
else
  accElementType = builder.getIntegerType(32);
auto accType = outputShapedType.clone(accElementType);
// Tails of the quantized-op builders; each records the op's result type.
    DenseArrayAttr kernel, DenseArrayAttr stride,
    DenseArrayAttr pad, TypeAttr accType) {
  // ...
  result.types.push_back(outputType);
// ...
result.types.push_back(outputType);
// ...
result.types.push_back(outputType);
// ...
result.types.push_back(outputType);
static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
                                           SmallVector<int64_t> &outShape) {
  int64_t outRank = 0;
  for (int i = 0, e = operands.size(); i != e; ++i) {
    auto shape = operands.getShape(i);
    if (!shape.hasRank()) {
      return failure();
    }
    outRank = std::max<int64_t>(outRank, shape.getRank());
  }

  outShape.resize(outRank, 1);

  for (int i = 0, e = operands.size(); i != e; ++i) {
    auto shape = operands.getShape(i);
    auto rankDiff = outShape.size() - shape.getRank();

    for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
      auto dim1 = outShape[i + rankDiff];
      auto dim2 = shape.getDimSize(i);
      auto resolvedDim = dim1;

      if (dim1 == 1) {
        resolvedDim = dim2;
      } else if (dim2 == 1) {
        resolvedDim = dim1;
      } else if (dim1 != dim2) {
        return failure();
      }
      outShape[i + rankDiff] = resolvedDim;
    }
  }
  return success();
}
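// Illustrative walk-through (not from the original source): broadcasting
// shapes [2, 1, 3] and [4, 1] first right-aligns the ranks ([2, 1, 3] vs
// [1, 4, 1]); each size-1 dimension then adopts the other operand's size,
// yielding outShape = [2, 4, 3]. Two differing non-unit sizes fail.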
LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ArgMaxOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  IntegerAttr axis = adaptor.getProperties().axis;
  int32_t axisVal = axis.getValue().getSExtValue();

  if (!inputShape.hasRank()) {
    // ...
  }

  // Drop the axis dimension; every other dimension is copied through.
  outShape.reserve(inputShape.getRank() - 1);
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
    if (i == axisVal)
      continue;
    outShape.push_back(inputShape.getDimSize(i));
  }
LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    RFFT2dOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput().getType());

  if (!inputShape.hasRank())
    return failure();

  SmallVector<int64_t> outputShape;
  outputShape.resize(3, ShapedType::kDynamic);
  outputShape[0] = inputShape.getDimSize(0);
  outputShape[1] = inputShape.getDimSize(1);
  int64_t inWidth = inputShape.getDimSize(2);

  // The output width is only known when the input width is static.
  if (inWidth != ShapedType::kDynamic)
    outputShape[2] = inWidth / 2 + 1;
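// Illustrative sanity check of the width arithmetic (assumed values, not
// from the original source): a static input width of 8 keeps only the
// non-redundant half of the spectrum plus the DC term.
static_assert(8 / 2 + 1 == 5, "rfft2d output width for an input width of 8");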
LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    FFT2dOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  inferredReturnShapes.push_back(
      ShapedTypeComponents(ShapeAdaptor(adaptor.getInputReal().getType())));
  inferredReturnShapes.push_back(
      ShapedTypeComponents(ShapeAdaptor(adaptor.getInputImag().getType())));
LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ConcatOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  const Properties &prop = adaptor.getProperties();
  int32_t axis = prop.axis.getValue().getSExtValue();
  llvm::SmallVector<int64_t> outputShape;

  // Infer all non-axis dimension sizes from the ranked inputs.
  bool hasRankedInput = false;
  for (auto operand : adaptor.getOperands()) {
    ShapeAdaptor operandShape(operand.getType());
    if (!operandShape.hasRank())
      continue;

    // Copy the operand's rank the first time a ranked input is seen.
    if (!hasRankedInput)
      outputShape.resize(operandShape.getRank(), ShapedType::kDynamic);

    // Copy every static, non-axis dimension; mismatches are an error.
    for (int i = 0, s = operandShape.getRank(); i < s; i++) {
      if (i == axis || operandShape.isDynamicDim(i))
        continue;
      if (outputShape[i] == ShapedType::kDynamic)
        outputShape[i] = operandShape.getDimSize(i);
      if (outputShape[i] != operandShape.getDimSize(i))
        return emitOptionalError(location,
                                 "Cannot concat tensors with different sizes"
                                 " on the non-axis dimension ",
                                 i);
    }

    hasRankedInput = true;
  }

  Type inputType =
      llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType();
  if (!hasRankedInput) {
    inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
    return success();
  }

  // The axis dimension is the sum of all static axis dimensions.
  int64_t concatDimSize = 0;
  for (auto operand : adaptor.getOperands()) {
    ShapeAdaptor operandShape(operand.getType());
    if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
      concatDimSize = ShapedType::kDynamic;
      break;
    }
    concatDimSize += operandShape.getDimSize(axis);
  }

  outputShape[axis] = concatDimSize;
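// Illustrative example (not from the original source): concatenating
// [2, 3, 4] and [2, 5, 4] along axis 1 requires the non-axis dimensions to
// match and sums the axis dimension (3 + 5), giving [2, 8, 4]. One dynamic
// axis dimension makes the output's axis dimension dynamic.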
LogicalResult tosa::EqualOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    // ...

bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  if (l.size() != r.size() || l.size() != 1)
    return false;
  return succeeded(verifyCompatibleShape(l[0], r[0]));
}
LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    FullyConnectedOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  ShapeAdaptor weightShape(adaptor.getWeight().getType());
  ShapeAdaptor biasShape(adaptor.getBias().getType());

  // All shapes are dynamic.
  SmallVector<int64_t> outShape;
  outShape.resize(2, ShapedType::kDynamic);

  if (inputShape.hasRank()) {
    outShape[0] = inputShape.getDimSize(0);
  }
  if (weightShape.hasRank()) {
    outShape[1] = weightShape.getDimSize(0);
  }
  if (biasShape.hasRank()) {
    outShape[1] = outShape[1] == ShapedType::kDynamic ? biasShape.getDimSize(0)
                                                      : outShape[1];
  }
LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    MatMulOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor lhsShape(adaptor.getA().getType());
  ShapeAdaptor rhsShape(adaptor.getB().getType());

  // All shapes are dynamic.
  SmallVector<int64_t> outShape;
  outShape.resize(3, ShapedType::kDynamic);

  if (lhsShape.hasRank()) {
    outShape[0] = lhsShape.getDimSize(0);
    outShape[1] = lhsShape.getDimSize(1);
  }
  if (rhsShape.hasRank()) {
    outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(0)
                                                      : outShape[0];
    outShape[2] = rhsShape.getDimSize(2);
  }
LogicalResult tosa::PadOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    PadOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput1().getType());
  ShapeAdaptor paddingShape(adaptor.getPadding().getType());
  SmallVector<int64_t> outputShape;

  // If both inputs have unknown shape, the output shape is unknown too.
  if (!inputShape.hasRank() && !paddingShape.hasRank()) {
    // ...
  }

  // With an unranked input, the padding shape's first dimension gives the
  // output rank; a non-constant padding value makes every dimension dynamic.
  if (!inputShape.hasRank()) {
    if (paddingShape.isDynamicDim(0)) {
      // ...
    }
    outputShape.resize(paddingShape.getDimSize(0), ShapedType::kDynamic);
    // ...
  }
  // ...
  outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
  // ...

  SmallVector<int64_t> paddingValues;
  for (auto val : paddings) {
    paddingValues.push_back(val.getSExtValue());
  }

  outputShape.reserve(inputShape.getRank());
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
    if (inputShape.isDynamicDim(i)) {
      outputShape.push_back(ShapedType::kDynamic);
      continue;
    }
    outputShape.push_back(inputShape.getDimSize(i) + paddingValues[i * 2] +
                          paddingValues[i * 2 + 1]);
  }
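// Illustrative sanity check (assumed values, not from the original source):
// a static dimension of size 4 padded with [low, high] = [1, 2] grows to 7.
static_assert(4 + 1 + 2 == 7, "pad output-size arithmetic");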
static SmallVector<int64_t> convertToMlirShape(ArrayRef<int64_t> shape) {
  return to_vector(llvm::map_range(shape, [](int64_t dim) {
    return dim == -1 ? ShapedType::kDynamic : dim;
  }));
}
LogicalResult tosa::SliceOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    SliceOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  inferredReturnShapes.push_back(
      ShapedTypeComponents(convertToMlirShape(adaptor.getSize())));
  return success();
}
LogicalResult tosa::SliceOp::verify() {
  auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
  if (!inputType)
    return success();

  if (static_cast<size_t>(inputType.getRank()) != getStart().size())
    return emitOpError(
        "length of start attribute is not equal rank of input shape");

  if (static_cast<size_t>(inputType.getRank()) != getSize().size())
    return emitOpError(
        "length of size attribute is not equal rank of input shape");

  return success();
}
LogicalResult tosa::TableOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TableOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (!inputShape.hasRank()) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
    return success();
  }
  inferredReturnShapes.resize(1);
  inputShape.getDims(inferredReturnShapes[0]);
  return success();
}
LogicalResult tosa::TileOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TileOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ArrayRef<int64_t> multiples = adaptor.getMultiples();
  ShapeAdaptor inputShape(adaptor.getInput1().getType());
  SmallVector<int64_t> outputShape;
  if (!inputShape.hasRank()) {
    outputShape.resize(multiples.size(), ShapedType::kDynamic);
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  } else if (static_cast<size_t>(inputShape.getRank()) != multiples.size())
    return failure();

  // Every static dimension is multiplied by its tile count.
  outputShape.reserve(multiples.size());
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
    int64_t dim = inputShape.getDimSize(i);
    if (dim != ShapedType::kDynamic)
      dim *= multiples[i];
    outputShape.push_back(dim);
  }
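// Illustrative example (not from the original source): a [2, 3] input with
// multiples = [3, 2] produces a [6, 6] output; dynamic input dimensions
// stay dynamic regardless of their multiple.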
LogicalResult tosa::TileOp::verify() {
  ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
  ShapedType outputType = llvm::cast<ShapedType>(getType());
  auto multiples = getMultiples();

  if (inputType.hasRank()) {
    if (static_cast<size_t>(inputType.getRank()) != multiples.size())
      return emitOpError("expect 'multiples' array to have length ")
             << inputType.getRank() << " but got " << multiples.size() << ".";
    if (outputType.hasRank() && inputType.getRank() != outputType.getRank())
      return emitOpError("expect same input and output tensor rank.");
  } else if (outputType.hasRank() &&
             static_cast<size_t>(outputType.getRank()) != multiples.size())
    return emitOpError("expect 'multiples' array to have length ")
           << outputType.getRank() << " but got " << multiples.size() << ".";

  return success();
}
if (l.size() != r.size() || l.size() != 1)
  return false;
LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ReshapeOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput1().getType());
  Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
  llvm::SmallVector<int64_t> newShapeValue =
      convertToMlirShape(adaptor.getNewShape());

  // Without a static input shape, the new_shape attribute is taken as exact.
  if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
    inferredReturnShapes.push_back(
        ShapedTypeComponents(newShapeValue, inputType));
    return success();
  }

  // Multiply the static dimensions to infer the remaining dynamic one.
  int64_t numElements = inputShape.getNumElements();
  int64_t staticMul = 1;
  for (auto val : newShapeValue) {
    if (!ShapedType::isDynamic(val)) {
      staticMul *= val;
    }
  }

  // The dynamic dimension gets everything that is left.
  for (auto &val : newShapeValue) {
    if (ShapedType::isDynamic(val))
      val = numElements / staticMul;
  }

  inferredReturnShapes.push_back(
      ShapedTypeComponents(newShapeValue, inputType));
  return success();
}
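// Illustrative sanity check (assumed values, not from the original source):
// reshaping 12 elements to new_shape = [3, -1] gives staticMul = 3, so the
// dynamic dimension resolves to 12 / 3 = 4 and the result is [3, 4].
static_assert(12 / 3 == 4, "reshape dynamic-dimension resolution");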
LogicalResult tosa::ReshapeOp::verify() {
  TensorType inputType = getInput1().getType();
  RankedTensorType outputType = getType();

  if (hasZeroDimension(inputType) || hasZeroDimension(outputType))
    return emitOpError() << "tensor has a dimension with size zero. Each "
                            "dimension of a tensor must have size >= 1";

  if ((int64_t)getNewShape().size() != outputType.getRank())
    return emitOpError() << "new shape does not match result rank";

  for (auto [newShapeDim, outputShapeDim] :
       zip(getNewShape(), outputType.getShape()))
    if (newShapeDim != -1 && outputShapeDim != ShapedType::kDynamic &&
        newShapeDim != outputShapeDim)
      return emitOpError() << "new shape is inconsistent with result shape";

  if (inputType.hasStaticShape() && outputType.hasStaticShape()) {
    int64_t inputElementsNum = inputType.getNumElements();
    int64_t outputElementsNum = outputType.getNumElements();
    if (inputElementsNum != outputElementsNum) {
      return emitOpError() << "cannot reshape " << inputElementsNum
                           << " elements into " << outputElementsNum;
    }
  }

  int missingDims = llvm::count(getNewShape(), -1);
  if (missingDims > 1)
    return emitOpError() << "expected at most one target dimension to be -1";
perms = llvm::to_vector(
    llvm::map_range(permsAttr.getValues<APInt>(),
                    [](const APInt &val) { return val.getSExtValue(); }));
LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TransposeOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput1().getType());
  ShapeAdaptor permsShape(adaptor.getPerms().getType());

  // We cannot infer anything from a rank-0 "permutation" tensor.
  if (permsShape.hasRank() && permsShape.getRank() == 0)
    return failure();

  // If input rank and permutation length are unknown, so is the output rank.
  if (!inputShape.hasRank() || !permsShape.hasRank() ||
      permsShape.isDynamicDim(0)) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
    return success();
  }

  // The permutation length must match the input rank.
  if (permsShape.getDimSize(0) != inputShape.getRank()) {
    return failure();
  }

  SmallVector<int64_t> outputShape;
  // Rank-0 input: no permutation matters.
  if (inputShape.getRank() == 0) {
    // ...
  }

  // If every input dimension is equal, any permutation gives the same shape.
  bool allTheSame = true;
  for (int i = 1, s = inputShape.getRank(); i < s; i++) {
    if (inputShape.getDimSize(0) != inputShape.getDimSize(i)) {
      allTheSame = false;
      break;
    }
  }
  if (allTheSame) {
    outputShape.resize(inputShape.getRank(), inputShape.getDimSize(0));
    // ...
  }

  outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);

  // A constant permutation determines the output shape exactly.
  DenseIntElementsAttr attr;
  if (matchPattern(adaptor.getPerms(), m_Constant(&attr)) &&
      attr.getType().getRank() == 1) {
    ShapeAdaptor permShape = attr;
    if (inputShape.getRank() != permShape.getRank())
      return emitOptionalError(location,
                               "constant permutation must be the same length"
                               " as the input rank");

    // Every permutation index must be in range of the input rank.
    for (int i = 0, e = inputShape.getRank(); i < e; i++) {
      if (inputShape.getRank() <= permShape.getDimSize(i))
        return failure();
    }

    outputShape.reserve(inputShape.getRank());
    for (int i = 0, s = inputShape.getRank(); i < s; i++) {
      outputShape[i] = inputShape.getDimSize(permShape.getDimSize(i));
    }
  }
LogicalResult tosa::TransposeOp::verify() {
  TensorType inputType = getInput1().getType();
  TensorType permType = getPerms().getType();
  TensorType outputType = getOutput().getType();

  if (permType.hasRank() && permType.getRank() != 1)
    return emitOpError()
           << "expected permutation tensor to be rank 1 but got rank "
           << permType.getRank();

  if (!permType.isDynamicDim(0) &&
      permType.getDimSize(0) != inputType.getRank())
    return emitOpError() << "expected permutation tensor dim 0 to have size "
                         << inputType.getRank()
                         << " (input rank) but got size "
                         << permType.getDimSize(0);

  if (/* ... && */ inputType.getRank() != outputType.getRank())
    return emitOpError()
           << "expected input tensor rank to equal result tensor rank";

  if (!permType.isDynamicDim(0) &&
      permType.getDimSize(0) != outputType.getRank())
    return emitOpError() << "expected permutation tensor dim 0 to have size "
                         << outputType.getRank()
                         << " (output rank) but got size "
                         << permType.getDimSize(0);

  SmallVector<int64_t> constantPerms;
  if (succeeded(getConstantPerms(constantPerms))) {
    // ...
    assert(/* ... */ &&
           "Unexpectedly found permutation tensor without rank");
    if (!isPermutationVector(constantPerms))
      return emitOpError() << "expected valid permutation tensor";
  }
  return success();
}
LogicalResult tosa::GatherOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    GatherOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape;
  outputShape.resize(3, ShapedType::kDynamic);

  ShapeAdaptor valuesShape(adaptor.getValues().getType());
  if (valuesShape.hasRank()) {
    outputShape[0] = valuesShape.getDimSize(0);
    outputShape[2] = valuesShape.getDimSize(2);
  }

  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
  if (indicesShape.hasRank()) {
    if (outputShape[0] == ShapedType::kDynamic)
      outputShape[0] = indicesShape.getDimSize(0);
    if (outputShape[1] == ShapedType::kDynamic)
      outputShape[1] = indicesShape.getDimSize(1);
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
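// Illustrative example (not from the original source): with values of
// shape [N, K, C] and indices of shape [N, W], the inferred output is
// [N, W, C]; dims 0 and 2 come from values, dim 1 from indices.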
LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ResizeOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t, 4> outputShape;
  outputShape.resize(4, ShapedType::kDynamic);

  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (!inputShape.hasRank())
    return failure();

  outputShape[0] = inputShape.getDimSize(0);
  outputShape[3] = inputShape.getDimSize(3);
  int64_t inputHeight = inputShape.getDimSize(1);
  int64_t inputWidth = inputShape.getDimSize(2);

  if ((inputHeight == ShapedType::kDynamic) ||
      (inputWidth == ShapedType::kDynamic))
    return failure();
  // ...
  outputShape[1] =
      (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) /
       scaleInt[1]) +
      1;
  outputShape[2] =
      (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) /
       scaleInt[3]) +
      1;
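// Illustrative sanity check (assumed values, not from the original source):
// inputHeight = 4, scale = 2/1 (numerator/denominator), offset = 0, and
// border = 0 resizes to ((4 - 1) * 2 - 0 + 0) / 1 + 1 = 7 rows.
static_assert(((4 - 1) * 2 - 0 + 0) / 1 + 1 == 7, "resize output height");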
LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ScatterOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape;
  outputShape.resize(3, ShapedType::kDynamic);

  ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());
  if (valuesInShape.hasRank()) {
    outputShape[0] = valuesInShape.getDimSize(0);
    outputShape[1] = valuesInShape.getDimSize(1);
    outputShape[2] = valuesInShape.getDimSize(2);
  }

  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
  if (indicesShape.hasRank()) {
    if (outputShape[0] == ShapedType::kDynamic)
      outputShape[0] = indicesShape.getDimSize(0);
  }

  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (inputShape.hasRank()) {
    if (outputShape[0] == ShapedType::kDynamic)
      outputShape[0] = inputShape.getDimSize(0);
    if (outputShape[2] == ShapedType::kDynamic)
      outputShape[2] = inputShape.getDimSize(2);
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
static LogicalResult ReduceInferReturnTypes(
    ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  int64_t axisVal = axis.getValue().getSExtValue();
  if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) {
    inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
    return success();
  }

  SmallVector<int64_t> outputShape;
  operandShape.getDims(outputShape);
  outputShape[axisVal] = 1;
  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
  return success();
}
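// Illustrative example (not from the original source): reducing a [2, 3, 4]
// operand along axis = 1 keeps the rank and sets that dimension to one,
// inferring [2, 1, 4].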
#define COMPATIBLE_RETURN_TYPES(OP)                                            \
  bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) {                 \
    if (l.size() != r.size() || l.size() != 1)                                 \
      return false;                                                            \
    if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0]))              \
      return false;                                                            \
    return succeeded(verifyCompatibleShape(l[0], r[0]));                       \
  }

#define REDUCE_SHAPE_INFER(OP)                                                 \
  LogicalResult OP::inferReturnTypeComponents(                                 \
      MLIRContext *context, ::std::optional<Location> location,                \
      OP::Adaptor adaptor,                                                     \
      SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {           \
    Type inputType =                                                           \
        llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \
    ShapeAdaptor inputShape(adaptor.getInput().getType());                     \
    const Properties &prop = adaptor.getProperties();                          \
    return ReduceInferReturnTypes(inputShape, inputType, prop.axis,            \
                                  inferredReturnShapes);                       \
  }                                                                            \
  COMPATIBLE_RETURN_TYPES(OP)

REDUCE_SHAPE_INFER(tosa::ReduceAllOp)
REDUCE_SHAPE_INFER(tosa::ReduceAnyOp)
REDUCE_SHAPE_INFER(tosa::ReduceMaxOp)
REDUCE_SHAPE_INFER(tosa::ReduceMinOp)
REDUCE_SHAPE_INFER(tosa::ReduceProdOp)
REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
#undef REDUCE_SHAPE_INFER
// ...
#undef COMPATIBLE_RETURN_TYPES
template <typename T>
static LogicalResult verifyReduceOp(T op) {
  // All TOSA reduce ops have input, output, and axis.
  TensorType inputType = op.getInput().getType();
  TensorType outputType = op.getOutput().getType();
  int32_t reduceAxis = op.getAxis();

  if (reduceAxis < 0) {
    op.emitOpError("reduce axis must not be negative");
    return failure();
  }

  int64_t inputRank = inputType.getRank();
  // Allow the special case where the input/output shape has rank 0 and the
  // axis is also 0.
  if (reduceAxis >= inputRank && !(reduceAxis == 0 && inputRank == 0)) {
    op.emitOpError("expect input tensor rank (")
        << inputRank << ") to be larger than reduce axis (" << reduceAxis
        << ")";
    return failure();
  }

  int64_t outputRank = outputType.getRank();
  if (inputType.hasRank() && outputRank != inputType.getRank()) {
    op.emitOpError(
        "expect output tensor rank to be equal to input tensor rank");
    return failure();
  }
  if (reduceAxis >= outputRank && !(reduceAxis == 0 && outputRank == 0)) {
    op.emitOpError("expect output tensor rank (")
        << outputRank << ") to be larger than reduce axis (" << reduceAxis
        << ")";
    return failure();
  }

  // The reduced dimension size must be 1 unless the output is rank 0.
  if (outputRank != 0) {
    auto outputShape = outputType.getShape();
    if (!outputType.isDynamicDim(reduceAxis) &&
        outputShape[reduceAxis] != 1) {
      op.emitOpError("expect reduced dimension size to be 1, got ")
          << outputShape[reduceAxis];
      return failure();
    }
  }
  return success();
}
#define NARY_SHAPE_INFER(OP)                                                   \
  LogicalResult OP::inferReturnTypeComponents(                                 \
      MLIRContext *context, ::std::optional<Location> location,                \
      ValueShapeRange operands, DictionaryAttr attributes,                     \
      OpaqueProperties properties, RegionRange regions,                        \
      SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {           \
    return NAryInferReturnTypes(operands, inferredReturnShapes);               \
  }
// ...
#undef PRED_SHAPE_INFER
static LogicalResult poolingInferReturnTypes(
    ShapeAdaptor inputShape, ArrayRef<int64_t> kernel,
    ArrayRef<int64_t> stride, ArrayRef<int64_t> pad,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape;
  outputShape.resize(4, ShapedType::kDynamic);
  // ...
  int64_t height = inputShape.getDimSize(1);
  int64_t width = inputShape.getDimSize(2);

  if (!ShapedType::isDynamic(height)) {
    int64_t padded = height + pad[0] + pad[1] - kernel[0];
    outputShape[1] = padded / stride[0] + 1;
  }

  if (!ShapedType::isDynamic(width)) {
    int64_t padded = width + pad[2] + pad[3] - kernel[1];
    outputShape[2] = padded / stride[1] + 1;
  }
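// Illustrative sanity check (assumed values, not from the original source):
// height = 32, no padding, kernel = 2, stride = 2 gives
// (32 + 0 + 0 - 2) / 2 + 1 = 16 output rows.
static_assert((32 + 0 + 0 - 2) / 2 + 1 == 16, "pooling output height");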
LogicalResult Conv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    Conv2DOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);

  int64_t inputWidth = ShapedType::kDynamic;
  int64_t inputHeight = ShapedType::kDynamic;
  int64_t weightWidth = ShapedType::kDynamic;
  int64_t weightHeight = ShapedType::kDynamic;

  // Input shape describes input width/height and batch.
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);
  }

  // Weight shape describes the filter width/height and output channels.
  ShapeAdaptor weightShape(adaptor.getWeight().getType());
  if (weightShape.hasRank()) {
    outputShape[3] = weightShape.getDimSize(0);
    weightHeight = weightShape.getDimSize(1);
    weightWidth = weightShape.getDimSize(2);
  }

  // Bias shape can describe the output channels.
  ShapeAdaptor biasShape(adaptor.getBias().getType());
  if (biasShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? biasShape.getDimSize(0)
                         : outputShape[3];
  }

  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
  llvm::ArrayRef<int64_t> padding = adaptor.getPad();
  llvm::ArrayRef<int64_t> stride = adaptor.getStride();

  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int64_t inputSize = inputHeight + padding[0] + padding[1];
    int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
  }

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int64_t inputSize = inputWidth + padding[2] + padding[3];
    int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
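// Illustrative sanity check (assumed values, not from the original source):
// a 7-row input, 3x3 kernel, unit dilation, no padding, stride 2:
// inputSize = 7, filterSize = 3, unstridedResult = 5, and the output has
// (5 - 1) / 2 + 1 = 3 rows.
static_assert((7 - ((3 - 1) * 1 + 1) + 1 - 1) / 2 + 1 == 3,
              "conv2d output-height arithmetic");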
LogicalResult Conv3DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    Conv3DOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamic);

  int64_t inputWidth = ShapedType::kDynamic;
  int64_t inputHeight = ShapedType::kDynamic;
  int64_t inputDepth = ShapedType::kDynamic;

  int64_t weightWidth = ShapedType::kDynamic;
  int64_t weightHeight = ShapedType::kDynamic;
  int64_t weightDepth = ShapedType::kDynamic;

  // Input shape describes input width/height/depth and batch.
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    inputDepth = inputShape.getDimSize(1);
    inputHeight = inputShape.getDimSize(2);
    inputWidth = inputShape.getDimSize(3);
  }

  // Weight shape describes the filter width/height/depth and output channels.
  ShapeAdaptor weightShape(adaptor.getWeight().getType());
  if (weightShape.hasRank()) {
    outputShape[4] = weightShape.getDimSize(0);
    weightDepth = weightShape.getDimSize(1);
    weightHeight = weightShape.getDimSize(2);
    weightWidth = weightShape.getDimSize(3);
  }

  // Bias shape can describe the output channels.
  ShapeAdaptor biasShape(adaptor.getBias().getType());
  if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) {
    outputShape[4] = biasShape.getDimSize(0);
  }

  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
  llvm::ArrayRef<int64_t> pad = adaptor.getPad();
  llvm::ArrayRef<int64_t> stride = adaptor.getStride();

  if (!ShapedType::isDynamic(inputDepth) &&
      !ShapedType::isDynamic(weightDepth)) {
    int32_t inputSize = inputDepth + pad[0] + pad[1];
    int32_t filterSize = (weightDepth - 1) * dilation[0] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
  }

  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int32_t inputSize = inputHeight + pad[2] + pad[3];
    int32_t filterSize = (weightHeight - 1) * dilation[1] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
  }

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int32_t inputSize = inputWidth + pad[4] + pad[5];
    int32_t filterSize = (weightWidth - 1) * dilation[2] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
LogicalResult AvgPool2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    AvgPool2dOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  const Properties &prop = adaptor.getProperties();
  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride,
                                 prop.pad, inferredReturnShapes);
}

LogicalResult MaxPool2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    MaxPool2dOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  const Properties &prop = adaptor.getProperties();
  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride,
                                 prop.pad, inferredReturnShapes);
}
LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    DepthwiseConv2DOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);

  int64_t inputWidth = ShapedType::kDynamic;
  int64_t inputHeight = ShapedType::kDynamic;
  int64_t inputChannels = ShapedType::kDynamic;

  int64_t weightWidth = ShapedType::kDynamic;
  int64_t weightHeight = ShapedType::kDynamic;
  int64_t depthChannels = ShapedType::kDynamic;

  // Input shape describes input width/height and batch.
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);
    inputChannels = inputShape.getDimSize(3);
  }

  // Weight shape describes the filter width/height and channel multiplier.
  ShapeAdaptor weightShape(adaptor.getWeight().getType());
  if (weightShape.hasRank()) {
    weightHeight = weightShape.getDimSize(0);
    weightWidth = weightShape.getDimSize(1);
    inputChannels = ShapedType::isDynamic(inputChannels)
                        ? weightShape.getDimSize(2)
                        : inputChannels;
    depthChannels = weightShape.getDimSize(3);
  }

  // If both inputChannels and depthChannels are available, the output
  // channel count is their product.
  if (!ShapedType::isDynamic(inputChannels) &&
      !ShapedType::isDynamic(depthChannels)) {
    outputShape[3] = inputChannels * depthChannels;
  }

  // Bias shape can describe the output channels.
  ShapeAdaptor biasShape(adaptor.getBias().getType());
  if (biasShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? biasShape.getDimSize(0)
                         : outputShape[3];
  }

  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
  llvm::ArrayRef<int64_t> padding = adaptor.getPad();
  llvm::ArrayRef<int64_t> stride = adaptor.getStride();

  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int64_t inputSize = inputHeight + padding[0] + padding[1];
    int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
  }

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int64_t inputSize = inputWidth + padding[2] + padding[3];
    int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
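// Illustrative example (not from the original source): 8 input channels
// with a channel multiplier of 2 (the weight's last dimension) yield
// 8 * 2 = 16 output channels; the spatial arithmetic matches conv2d.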
LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TransposeConv2DOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape =
      convertToMlirShape(adaptor.getOutShape());

  int64_t inputWidth = ShapedType::kDynamic;
  int64_t inputHeight = ShapedType::kDynamic;
  int64_t weightWidth = ShapedType::kDynamic;
  int64_t weightHeight = ShapedType::kDynamic;

  // Input shape describes input width/height and batch.
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (inputShape.hasRank()) {
    outputShape[0] = ShapedType::isDynamic(outputShape[0])
                         ? inputShape.getDimSize(0)
                         : outputShape[0];
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);
  }

  // Weight shape describes the filter width/height and output channels.
  ShapeAdaptor weightShape(adaptor.getFilter().getType());
  if (weightShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? weightShape.getDimSize(0)
                         : outputShape[3];
    weightHeight = weightShape.getDimSize(1);
    weightWidth = weightShape.getDimSize(2);
  }

  // Bias shape can describe the output channels.
  ShapeAdaptor biasShape(adaptor.getBias().getType());
  if (biasShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? biasShape.getDimSize(0)
                         : outputShape[3];
  }

  llvm::ArrayRef<int64_t> padding = adaptor.getOutPad();
  llvm::ArrayRef<int64_t> stride = adaptor.getStride();

  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int64_t calculateSize =
        (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
    outputShape[1] =
        ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
  }

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int64_t calculateSize =
        (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
    outputShape[2] =
        ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
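// Illustrative sanity check (assumed values, not from the original source):
// a 3-row input, stride 2, no out_pad, and a 3-row filter expands to
// (3 - 1) * 2 + 0 + 0 + 3 = 7 output rows.
static_assert((3 - 1) * 2 + 0 + 0 + 3 == 7, "transpose_conv2d output height");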
LogicalResult tosa::IfOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    IfOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<tosa::YieldOp> yieldOps;
  for (Region *region : adaptor.getRegions()) {
    for (auto &block : *region)
      if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
        yieldOps.push_back(returnOp);
  }

  if (yieldOps.empty())
    return failure();

  // Get the initial type information for the yield op.
  llvm::SmallVector<ValueKnowledge> resultKnowledge;
  resultKnowledge.reserve(yieldOps.front().getNumOperands());
  for (auto operand : yieldOps.front().getOperands()) {
    resultKnowledge.push_back(
        ValueKnowledge::getKnowledgeFromType(operand.getType()));
  }

  for (auto yieldOp : yieldOps) {
    if (resultKnowledge.size() != yieldOp.getNumOperands())
      return failure();

    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      int32_t index = it.index();
      auto meet = ValueKnowledge::meet(
          resultKnowledge[index],
          ValueKnowledge::getKnowledgeFromType(it.value().getType()));
      if (!meet)
        continue;
      resultKnowledge[index] = meet;
    }
  }

  for (const ValueKnowledge &result : resultKnowledge)
    inferredReturnShapes.push_back(result.getShapedTypeComponents());

  return success();
}
LogicalResult tosa::WhileOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    WhileOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<tosa::YieldOp> yieldOps;
  for (auto &block : adaptor.getBody())
    if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
      yieldOps.push_back(returnOp);

  if (yieldOps.empty())
    return failure();

  // Get the initial type information from the yield operands.
  llvm::SmallVector<ValueKnowledge> resultKnowledge;
  resultKnowledge.reserve(yieldOps.front().getNumOperands());
  for (auto operand : yieldOps.front().getOperands()) {
    resultKnowledge.push_back(
        ValueKnowledge::getKnowledgeFromType(operand.getType()));
  }

  for (auto yieldOp : yieldOps) {
    if (resultKnowledge.size() != yieldOp.getNumOperands())
      return failure();

    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      int32_t index = it.index();
      if (auto meet = ValueKnowledge::meet(
              resultKnowledge[index],
              ValueKnowledge::getKnowledgeFromType(it.value().getType()))) {
        resultKnowledge[index] = meet;
      }
    }
  }

  for (const ValueKnowledge &result : resultKnowledge)
    inferredReturnShapes.push_back(result.getShapedTypeComponents());

  return success();
}
std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
  if (auto vt = llvm::dyn_cast<VectorType>(getType()))
    return llvm::to_vector<4>(vt.getShape());
  return std::nullopt;
}
void IfOp::print(OpAsmPrinter &p) {
  bool printBlockTerminators = false;

  p << " " << getCond();
  if (!getResults().empty()) {
    p << " -> (" << getResultTypes() << ")";
    // Print yield explicitly if the op defines values.
    printBlockTerminators = true;
  }
  p << ' ';
  p.printRegion(getThenBranch(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/printBlockTerminators);

  // Print the 'else' region if it exists and has a block.
  auto &elseRegion = getElseBranch();
  if (!elseRegion.empty()) {
    p << " else ";
    p.printRegion(elseRegion,
                  /*printEntryBlockArgs=*/false,
                  /*printBlockTerminators=*/printBlockTerminators);
  }
LogicalResult ReverseOp::verify() {
  TensorType inputType = getInput().getType();
  TensorType outputType = getOutput().getType();
  int32_t reverseAxis = getAxis();

  if (reverseAxis < 0)
    return emitOpError("expected non-negative reverse axis");

  int64_t inputRank = inputType.getRank();
  // Allow the special case where the input/output has rank 0 and the axis
  // is also 0.
  if (reverseAxis >= inputRank && !(reverseAxis == 0 && inputRank == 0))
    return emitOpError("expect input tensor rank (")
           << inputRank << ") to be larger than reverse axis (" << reverseAxis
           << ")";

  int64_t outputRank = outputType.getRank();
  if (inputType.hasRank() && outputRank != inputType.getRank())
    return emitOpError(
        "expect output tensor rank to be equal to input tensor rank");
  if (reverseAxis >= outputRank && !(reverseAxis == 0 && outputRank == 0))
    return emitOpError("expect output tensor rank (")
           << outputRank << ") to be larger than reverse axis ("
           << reverseAxis << ")";

  return success();
}
FunctionType functionType;
// ...
result.addTypes(functionType.getResults());

if (functionType.getNumInputs() != operands.size()) {
  return parser.emitError(typeLoc)
         << "expected as many input types as operands "
         << "(expected " << operands.size() << " got "
         << functionType.getNumInputs() << ")";
}

// Propagate the parsed types onto the region arguments.
for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
  regionArgs[i].type = functionType.getInput(i);
static void printInitializationList(OpAsmPrinter &parser,
                                    Block::BlockArgListType blocksArgs,
                                    ValueRange initializers,
                                    StringRef prefix = "") {
  assert(blocksArgs.size() == initializers.size() &&
         "expected same length of arguments and initializers");
  if (initializers.empty())
    return;

  parser << prefix << '(';
  llvm::interleaveComma(
      llvm::zip(blocksArgs, initializers), parser,
      [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
  parser << ')';
}
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"