TosaOps.cpp
1//===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// \file
10// This file implements the TOSA Specification:
11// https://www.mlplatform.org/tosa/tosa_spec.html
12//
13//===----------------------------------------------------------------------===//
14
24#include "mlir/IR/Matchers.h"
28#include "llvm/ADT/APFloat.h"
29#include "llvm/ADT/TypeSwitch.h"
30
31#include <numeric>
32
33using namespace mlir;
34using namespace mlir::tosa;
35
36#include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"
38
39//===----------------------------------------------------------------------===//
40// Tosa dialect interface includes.
41//===----------------------------------------------------------------------===//
42
43#include "mlir/Dialect/Tosa/IR/TosaAvailability.cpp.inc"
44#include "mlir/Dialect/Tosa/IR/TosaEnums.cpp.inc"
45#include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
46#include "mlir/Dialect/Tosa/IR/TosaOpAvailabilityImpl.inc"
47
48namespace {
49#include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"
50
51//===----------------------------------------------------------------------===//
52// Dialect Function Inliner Interface.
53//===----------------------------------------------------------------------===//
54struct TosaInlinerInterface : public DialectInlinerInterface {
55 using DialectInlinerInterface::DialectInlinerInterface;
56
57 //===--------------------------------------------------------------------===//
58 // Analysis Hooks.
59 //===--------------------------------------------------------------------===//
60
61 /// All operations can be inlined by default.
62 bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
63 IRMapping &map) const final {
64 return true;
65 }
66
67 /// All regions with If and While parent operators can be inlined.
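  /// For example, a region whose parent op is a tosa::IfOp or tosa::WhileOp
  /// is a legal inlining destination; regions of any other parent op are
  /// rejected by the check below.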
68 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
69 IRMapping &map) const final {
70 return (isa<tosa::IfOp>(dest->getParentOp()) ||
71 isa<tosa::WhileOp>(dest->getParentOp()));
72 }
73};
74
75/// This class implements the bytecode interface for the Tosa dialect.
76struct TosaDialectBytecodeInterface : public BytecodeDialectInterface {
77 TosaDialectBytecodeInterface(Dialect *dialect)
78 : BytecodeDialectInterface(dialect) {}
79
80 //===--------------------------------------------------------------------===//
81 // Attributes
82
83 Attribute readAttribute(DialectBytecodeReader &reader) const override {
84 return ::readAttribute(getContext(), reader);
85 }
86
87 LogicalResult writeAttribute(Attribute attr,
88 DialectBytecodeWriter &writer) const override {
89 return ::writeAttribute(attr, writer);
90 }
91
92 //===--------------------------------------------------------------------===//
93 // Types
94
95 Type readType(DialectBytecodeReader &reader) const override {
96 return ::readType(getContext(), reader);
97 }
98
99 LogicalResult writeType(Type type,
100 DialectBytecodeWriter &writer) const override {
101 return ::writeType(type, writer);
102 }
103
104 void writeVersion(DialectBytecodeWriter &writer) const final {
105 // TODO: Populate.
106 }
107
108 std::unique_ptr<DialectVersion>
109 readVersion(DialectBytecodeReader &reader) const final {
110 // TODO: Populate
111 reader.emitError("Dialect does not support versioning");
112 return nullptr;
113 }
114
115 LogicalResult upgradeFromVersion(Operation *topLevelOp,
116 const DialectVersion &version) const final {
117 return success();
118 }
119};
120
121} // namespace
122
123//===----------------------------------------------------------------------===//
124// TOSA control flow support.
125//===----------------------------------------------------------------------===//
126
127/// Returns the while loop body.
128SmallVector<Region *> tosa::WhileOp::getLoopRegions() {
129 return {&getBodyGraph()};
130}
131
132//===----------------------------------------------------------------------===//
133// TOSA variable operator support.
134//===----------------------------------------------------------------------===//
135
136SmallVector<int64_t> mlir::tosa::convertToMlirShape(ArrayRef<int64_t> shape) {
137 return to_vector(llvm::map_range(shape, [](int64_t dim) {
138 return dim == -1 ? ShapedType::kDynamic : dim;
139 }));
140}
141
142// Returns the ranked tensor type of a variable op, built from its var_shape
142// and element type attributes.
143RankedTensorType mlir::tosa::getVariableType(tosa::VariableOp variableOp) {
144 Type elementType = variableOp.getType();
145 DenseIntElementsAttr varShapeAttr = variableOp.getVarShape();
146 auto shape = convertToMlirShape(to_vector(varShapeAttr.getValues<int64_t>()));
147 return RankedTensorType::get(shape, elementType);
148}
149
150//===----------------------------------------------------------------------===//
151// Tosa dialect initialization.
152//===----------------------------------------------------------------------===//
153
154void TosaDialect::initialize() {
155 addTypes<
156#define GET_TYPEDEF_LIST
157#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
158 >();
159 addOperations<
160#define GET_OP_LIST
161#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
162 >();
163 addAttributes<
164#define GET_ATTRDEF_LIST
165#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
166 >();
167 addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
168 declarePromisedInterfaces<
169 shard::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
170 ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp,
171 LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,
172 LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,
173 BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp,
174 NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp,
175 GreaterEqualOp, MatMulOp>();
176}
177
178Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
179 Type type, Location loc) {
180 // Tosa dialect constants only support ElementsAttr unlike standard dialect
181 // constant which supports all attributes.
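  // For example (illustrative): a DenseIntElementsAttr paired with a
  // !tosa.shape result type materializes as tosa.const_shape, while any other
  // ElementsAttr materializes as tosa.const.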
182 if (llvm::isa<shapeType>(type) && llvm::isa<DenseIntElementsAttr>(value)) {
183 return tosa::ConstShapeOp::create(builder, loc, type,
184 llvm::cast<DenseIntElementsAttr>(value));
185 }
186 if (llvm::isa<ElementsAttr>(value))
187 return tosa::ConstOp::create(builder, loc, type,
188 llvm::cast<ElementsAttr>(value));
189 return nullptr;
190}
191
192//===----------------------------------------------------------------------===//
193// Parsers and printers
194//===----------------------------------------------------------------------===//
195
196namespace {
197
198ParseResult getShapeAndElementType(OpAsmParser &parser, Type parsedType,
199 DenseElementsAttr &varShapeAttr,
200 TypeAttr &typeAttr) {
201 if (auto shapedType = dyn_cast<ShapedType>(parsedType)) {
202 if (!shapedType.hasRank())
203 return parser.emitError(parser.getCurrentLocation())
204 << "expected ranked type";
205
206 auto elementType = shapedType.getElementType();
207 typeAttr = TypeAttr::get(elementType);
208 ArrayRef<int64_t> shape = shapedType.getShape();
209 Builder builder(parser.getContext());
210 varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
211 return success();
212 }
213 return parser.emitError(parser.getCurrentLocation())
214 << "expected shaped type";
215}
216
217} // namespace
218
219// parses the optional initial value or type for a tosa variable
220// with initial value:
221// tosa.variable @name = dense<0.0> : tensor<1x8xf32>
222//
223// without initial value:
224// tosa.variable @name : tensor<1x8xf32>
225ParseResult mlir::tosa::parseVariableOpTypeOrInitialValue(
226 OpAsmParser &parser, DenseElementsAttr &varShapeAttr, TypeAttr &typeAttr,
227 Attribute &initialValueAttr) {
228 if (succeeded(parser.parseOptionalEqual())) {
229 if (failed(parser.parseAttribute(initialValueAttr))) {
230 return parser.emitError(parser.getCurrentLocation())
231 << "expected attribute";
232 }
233 if (auto typedAttr = dyn_cast<TypedAttr>(initialValueAttr)) {
234 return getShapeAndElementType(parser, typedAttr.getType(), varShapeAttr,
235 typeAttr);
236 }
237 return parser.emitError(parser.getCurrentLocation())
238 << "expected Typed attr";
239 }
240
241 initialValueAttr = nullptr;
242 Type parsedType;
243 if (failed(parser.parseColonType(parsedType))) {
244 return parser.emitError(parser.getCurrentLocation())
245 << "expected type after colon";
246 }
247 return getShapeAndElementType(parser, parsedType, varShapeAttr, typeAttr);
248}
249
250void mlir::tosa::printVariableOpTypeOrInitialValue(
251 OpAsmPrinter &p, Operation *op, DenseElementsAttr varShapeAttr,
252 TypeAttr typeAttr, Attribute initialValueAttr) {
253 bool needsSpace = false;
254 if (!dyn_cast_or_null<TypedAttr>(initialValueAttr)) {
255 auto shape =
256 convertToMlirShape(to_vector(varShapeAttr.getValues<int64_t>()));
257 Type elementType = typeAttr.getValue();
258 RankedTensorType tensorType =
259 RankedTensorType::get(ArrayRef<int64_t>(shape), elementType);
260 auto tensorTypeAttr = TypeAttr::get(tensorType);
261 p << ": ";
262 p.printAttribute(tensorTypeAttr);
263 needsSpace = true; // subsequent attr value needs a space separator
264 }
265 if (initialValueAttr) {
266 if (needsSpace)
267 p << ' ';
268 p << "= ";
269 p.printAttribute(initialValueAttr);
270 }
271}
272
273namespace {
274
275// parse attributes with special handling for tosa enum attributes
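// For example (illustrative syntax), this allows an op to be written with a
// bare enum keyword such as
//   rounding_mode = SINGLE_ROUND
// instead of the fully qualified attribute form; the same applies to the
// mode, nan_mode and block_size attributes handled below.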
276template <typename EnumType>
277ParseResult parseAttrEntryWithEnumHandling(OpAsmParser &parser,
278 NamedAttrList &outAttrs) {
279 llvm::StringRef name;
280 if (parser.parseOptionalKeyword(&name) || parser.parseEqual())
281 return failure();
282
283 // special handling: rounding_mode accepts a *bare* RoundingMode enum
284 // keyword.
285 llvm::StringRef kw;
286 if constexpr (std::is_same_v<EnumType, tosa::RoundingMode>) {
287 if (name == "rounding_mode" &&
288 succeeded(parser.parseOptionalKeyword(&kw))) {
289 auto sym = symbolizeRoundingMode(kw);
290 if (!sym)
291 return parser.emitError(parser.getCurrentLocation())
292 << "invalid rounding_mode value: " << kw;
293 auto attr = RoundingModeAttr::get(parser.getContext(), sym.value());
294 outAttrs.push_back(NamedAttribute(name, attr));
295 return success();
296 }
297 }
298 // special handling: mode accepts a *bare* ResizeMode enum keyword.
299 if constexpr (std::is_same_v<EnumType, tosa::ResizeMode>) {
300 if (name == "mode" && succeeded(parser.parseOptionalKeyword(&kw))) {
301 auto sym = symbolizeResizeMode(kw);
302 if (!sym)
303 return parser.emitError(parser.getCurrentLocation())
304 << "invalid resize mode value: " << kw;
305 auto attr = ResizeModeAttr::get(parser.getContext(), sym.value());
306 outAttrs.push_back(NamedAttribute(name, attr));
307 return success();
308 }
309 }
310 // special handling: nan_mode accepts a *bare* NanPropagationMode enum
311 // keyword.
312 if constexpr (std::is_same_v<EnumType, tosa::NanPropagationMode>) {
313 if (name == "nan_mode" && succeeded(parser.parseOptionalKeyword(&kw))) {
314 auto sym = symbolizeNanPropagationMode(kw);
315 if (!sym)
316 return parser.emitError(parser.getCurrentLocation())
317 << "invalid nan_mode value: " << kw;
318 auto attr = NanPropagationModeAttr::get(parser.getContext(), sym.value());
319 outAttrs.push_back(NamedAttribute(name, attr));
320 return success();
321 }
322 }
323
324 // special handling: block_size accepts a *bare* BlockSize enum keyword.
325 if constexpr (std::is_same_v<EnumType, tosa::BlockSize>) {
326 if (name == "block_size" && succeeded(parser.parseOptionalKeyword(&kw))) {
327 auto sym = symbolizeBlockSize(kw);
328 if (!sym)
329 return parser.emitError(parser.getCurrentLocation())
330 << "invalid block_size value: " << kw;
331 auto attr = BlockSizeAttr::get(parser.getContext(), sym.value());
332 outAttrs.push_back(NamedAttribute(name, attr));
333 return success();
334 }
335 }
336
337 // Default path: parse any normal attribute literal, including fully qualified
338 // enum keyword
339 Attribute attr;
340 return parser.parseAttribute(attr, name, outAttrs);
341}
342
343template <typename EnumType>
344ParseResult parseWithEnumHandling(OpAsmParser &parser, OperationState &result) {
345 // parse operands
346 SmallVector<OpAsmParser::UnresolvedOperand> operands;
347 if (parser.parseCommaSeparatedList(
348 [&]() { return parser.parseOperand(operands.emplace_back()); }))
349 return failure();
350
351 // Parse { attr-dict } with special handling for enum bare token
352 NamedAttrList attrs;
353 if (succeeded(parser.parseOptionalLBrace()) &&
354 failed(parser.parseOptionalRBrace())) {
355 do {
356 if (parseAttrEntryWithEnumHandling<EnumType>(parser, attrs))
357 return failure();
358 } while (succeeded(parser.parseOptionalComma()));
359 if (parser.parseRBrace())
360 return failure();
361 }
362
363 FunctionType fnTy;
364 if (parser.parseColonType(fnTy))
365 return failure();
366
367 // Resolve operands and types
368 if (failed(parser.resolveOperands(operands, fnTy.getInputs(),
369 parser.getCurrentLocation(),
370 result.operands)))
371 return failure();
372
373 result.addTypes(fnTy.getResults());
374 result.addAttributes(attrs);
375
376 return success();
377}
378
379void printNamedAttr(OpAsmPrinter &parser, const NamedAttribute namedAttr) {
380 parser << namedAttr.getName().strref() << " = ";
381 auto attr = namedAttr.getValue();
382 if (auto roundingModeAttr = dyn_cast<tosa::RoundingModeAttr>(attr)) {
383 parser << roundingModeAttr.getValue();
384 } else if (auto resizeModeAttr = dyn_cast<tosa::ResizeModeAttr>(attr)) {
385 parser << resizeModeAttr.getValue();
386 } else if (auto nanPropagationModeAttr =
387 dyn_cast<tosa::NanPropagationModeAttr>(attr)) {
388 parser << nanPropagationModeAttr.getValue();
389 } else if (auto blockSizeAttr = dyn_cast<tosa::BlockSizeAttr>(attr)) {
390 parser << blockSizeAttr.getValue();
391 } else {
392 parser.printAttribute(attr);
393 }
394}
395
396// print with special handling for default valued NanPropagationMode attribute
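// For example (illustrative): an op whose nan_mode is the default PROPAGATE
// value is printed without the attribute, while a non-default value such as
// IGNORE is printed explicitly.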
397void printWithNanPropagationHandling(OpAsmPrinter &parser, Operation *op) {
398 parser << " ";
399 parser.printOperands(op->getOperands());
400
401 NamedAttrList toPrint(op->getAttrs());
402 // remove default NanPropagate attribute
403 const auto kDefaultNanValue = NanPropagationMode::PROPAGATE;
404 for (auto attr : op->getAttrs()) {
405 if (auto nanAttr = dyn_cast<NanPropagationModeAttr>(attr.getValue())) {
406 if (nanAttr.getValue() == kDefaultNanValue) {
407 // elide from toPrint
408 toPrint.erase(attr.getName());
409 break;
410 }
411 }
412 }
413
414 if (!toPrint.empty()) {
415 parser << " {";
416 llvm::interleaveComma(toPrint, parser, [&](const NamedAttribute namedAttr) {
417 printNamedAttr(parser, namedAttr);
418 });
419 parser << "}";
420 }
421
422 parser << " : ";
423 parser.printFunctionalType(op);
424}
425
426// print with special handling for enums: RoundingMode, ResizeMode
427void printWithEnumHandling(OpAsmPrinter &parser, Operation *op) {
428 parser << " ";
429 parser.printOperands(op->getOperands());
430
431 if (!op->getAttrs().empty()) {
432 parser << " {";
433 llvm::interleaveComma(op->getAttrs(), parser,
434 [&](const NamedAttribute namedAttr) {
435 printNamedAttr(parser, namedAttr);
436 });
437 parser << "}";
438 }
439
440 parser << " : ";
441 parser.printFunctionalType(op);
442}
443
444} // namespace
445
446ParseResult RescaleOp::parse(OpAsmParser &parser, OperationState &result) {
447 return parseWithEnumHandling<tosa::RoundingMode>(parser, result);
448}
449
450void RescaleOp::print(OpAsmPrinter &parser) {
451 printWithEnumHandling(parser, *this);
452}
453
454ParseResult ApplyScaleOp::parse(OpAsmParser &parser, OperationState &result) {
455 return parseWithEnumHandling<tosa::RoundingMode>(parser, result);
456}
457
458void ApplyScaleOp::print(OpAsmPrinter &parser) {
459 printWithEnumHandling(parser, *this);
460}
461
462ParseResult ResizeOp::parse(OpAsmParser &parser, OperationState &result) {
463 return parseWithEnumHandling<tosa::ResizeMode>(parser, result);
464}
465
466void ResizeOp::print(OpAsmPrinter &parser) {
467 printWithEnumHandling(parser, *this);
468}
469
470ParseResult ArgMaxOp::parse(OpAsmParser &parser, OperationState &result) {
471 return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
472}
473
474void ArgMaxOp::print(OpAsmPrinter &parser) {
475 printWithNanPropagationHandling(parser, *this);
476}
477
478ParseResult MaxPool2dOp::parse(OpAsmParser &parser, OperationState &result) {
479 return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
480}
481
482void MaxPool2dOp::print(OpAsmPrinter &parser) {
483 printWithNanPropagationHandling(parser, *this);
484}
485
486ParseResult ClampOp::parse(OpAsmParser &parser, OperationState &result) {
487 return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
488}
489
490void ClampOp::print(OpAsmPrinter &parser) {
491 printWithNanPropagationHandling(parser, *this);
492}
493
494ParseResult MaximumOp::parse(OpAsmParser &parser, OperationState &result) {
495 return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
496}
497
498void MaximumOp::print(OpAsmPrinter &parser) {
499 printWithNanPropagationHandling(parser, *this);
500}
501
502ParseResult MinimumOp::parse(OpAsmParser &parser, OperationState &result) {
503 return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
504}
505
506void MinimumOp::print(OpAsmPrinter &parser) {
507 printWithNanPropagationHandling(parser, *this);
508}
509
510ParseResult ReduceMaxOp::parse(OpAsmParser &parser, OperationState &result) {
511 return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
512}
513
514void ReduceMaxOp::print(OpAsmPrinter &parser) {
515 printWithNanPropagationHandling(parser, *this);
516}
517
518ParseResult ReduceMinOp::parse(OpAsmParser &parser, OperationState &result) {
519 return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
520}
521
522void ReduceMinOp::print(OpAsmPrinter &parser) {
523 printWithNanPropagationHandling(parser, *this);
524}
525
526ParseResult MatmulTBlockScaledOp::parse(OpAsmParser &parser,
527 OperationState &result) {
528 return parseWithEnumHandling<tosa::BlockSize>(parser, result);
529}
530
531void MatmulTBlockScaledOp::print(OpAsmPrinter &parser) {
532 printWithEnumHandling(parser, *this);
533}
534
535ParseResult CastFromBlockScaledOp::parse(OpAsmParser &parser,
536 OperationState &result) {
537 return parseWithEnumHandling<tosa::BlockSize>(parser, result);
538}
539
540void CastFromBlockScaledOp::print(OpAsmPrinter &parser) {
541 printWithEnumHandling(parser, *this);
542}
543
544ParseResult CastToBlockScaledOp::parse(OpAsmParser &parser,
545 OperationState &result) {
546 return parseWithEnumHandling<tosa::BlockSize>(parser, result);
547}
548
549void CastToBlockScaledOp::print(OpAsmPrinter &parser) {
550 printWithEnumHandling(parser, *this);
551}
552
553//===----------------------------------------------------------------------===//
554// Tosa utilities.
555//===----------------------------------------------------------------------===//
556
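// Returns lhs / rhs when the division is exact, std::nullopt otherwise;
// e.g. idivCheck(10, 2) yields 5 while idivCheck(9, 2) yields std::nullopt.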
557static std::optional<int64_t> idivCheck(const int64_t lhs, const int64_t rhs) {
558 if (lhs % rhs != 0)
559 return std::nullopt;
560 return lhs / rhs;
561}
562
563Type mlir::tosa::getStorageElementTypeOrSelf(Type type) {
564 auto srcType = getElementTypeOrSelf(type);
565 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcType))
566 srcType = quantType.getStorageType();
567 return srcType;
568}
569
570Type mlir::tosa::getStorageElementTypeOrSelf(Value value) {
571 return getStorageElementTypeOrSelf(value.getType());
572}
573
574static LogicalResult verifyRescaleValueAndZpTypes(Operation *op, Value val,
575 Value valZp, StringRef name) {
576 Type eType = getStorageElementTypeOrSelf(val.getType());
577 Type eZpType = getStorageElementTypeOrSelf(valZp.getType());
578
579 bool bothInts =
580 mlir::isa<IntegerType>(eType) && mlir::isa<IntegerType>(eZpType);
581 bool sameBitWidth =
582 (eType.getIntOrFloatBitWidth() == eZpType.getIntOrFloatBitWidth());
583
584 if (!bothInts || !sameBitWidth) {
585 return op->emitOpError()
586 << "expected " << name << " and " << name
587 << "_zp to both be integer of the same bitwidth, but got " << eType
588 << " vs. " << eZpType;
589 }
590 return success();
591}
592
593// Create a pad_const tensor holding the value `val` in the required data type.
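// The result is a rank-1, single-element constant (shape {1}); `val` is
// stored as a float attribute for float element types and as an integer
// attribute otherwise, matching the storage element type of `src`.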
594Value mlir::tosa::createPadConstTensor(OpBuilder &builder, Location loc,
595 Value src, int32_t val) {
596 const auto srcType = getElementTypeOrSelf(src);
597 const auto srcElemType = getStorageElementTypeOrSelf(src);
598 const auto padConstType = mlir::RankedTensorType::get({1}, srcType);
599 const auto padConstEType = mlir::RankedTensorType::get({1}, srcElemType);
600 const auto padConstAttr{
601 llvm::isa<FloatType>(srcElemType)
602 ? DenseElementsAttr::get(padConstEType,
603 builder.getFloatAttr(srcElemType, val))
604 : DenseElementsAttr::get(padConstEType,
605 builder.getIntegerAttr(srcElemType, val))};
606 return tosa::ConstOp::create(builder, loc, padConstType, padConstAttr);
607}
608
610 if (dyn_cast<tosa::mxint8Type>(type))
611 return 8;
612 return type.getIntOrFloatBitWidth();
613}
614
615//===----------------------------------------------------------------------===//
616// TOSA Operator Verifiers.
617//===----------------------------------------------------------------------===//
618
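// Shared verifier for convolution-style ops: checks that the input, weight,
// bias and result element types are mutually consistent and that the zero
// points match the storage element types of their corresponding operands.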
619template <typename T>
620static LogicalResult verifyConvOp(T op) {
621 const auto inputType = llvm::dyn_cast<TensorType>(op.getInput().getType());
622 const auto weightType = llvm::dyn_cast<TensorType>(op.getWeight().getType());
623
624 auto inputEType = inputType.getElementType();
625 auto weightEType = weightType.getElementType();
626 auto biasEType =
627 llvm::cast<ShapedType>(op.getBias().getType()).getElementType();
628 auto resultEType =
629 llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
630 bool biasIsFloat = llvm::isa<FloatType>(biasEType);
631 bool resultIsFloat = llvm::isa<FloatType>(resultEType);
632
633 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
634 inputEType = quantType.getStorageType();
635
636 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType))
637 weightEType = quantType.getStorageType();
638
639 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
640 biasEType = quantType.getStorageType();
641
642 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
643 resultEType = quantType.getStorageType();
644
645 if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
646 // for now, only enforce bias element type == result element type for
647 // float types.
648 op.emitOpError(
649 "expect both bias and result to have same element type, got ")
650 << biasEType << " and " << resultEType;
651 return failure();
652 }
653
654 if (isa<Float8E5M2Type>(inputEType) || isa<Float8E4M3FNType>(inputEType) ||
655 isa<Float8E5M2Type>(weightEType) || isa<Float8E4M3FNType>(weightEType)) {
656 if (inputEType != weightEType) {
657 op.emitOpError(
658 "expect both input and weight to have same element type, got ")
659 << inputEType << " and " << weightEType;
660 return failure();
661 }
662 }
663
664 bool inputIsFloat = llvm::isa<FloatType>(inputEType);
665 bool weightIsFloat = llvm::isa<FloatType>(weightEType);
666
667 // Either both must be float or both non-float.
668 if (inputIsFloat != weightIsFloat) {
669 op.emitOpError(
670 "expect both input and weight to be float or not together, got ")
671 << inputEType << " and " << weightEType;
672 return failure();
673 }
674
675 auto inputZpEType = getStorageElementTypeOrSelf(op.getInputZp().getType());
676 if (inputEType != inputZpEType) {
677 return op.emitOpError("expect both input and its zero point are the same "
678 "element type, got ")
679 << inputEType << " and " << inputZpEType;
680 }
681
682 auto weightZpEType = getStorageElementTypeOrSelf(op.getWeightZp().getType());
683 if (weightEType != weightZpEType) {
684 return op.emitOpError("expect both weight and its zero point are the same "
685 "element type, got ")
686 << weightEType << " and " << weightZpEType;
687 }
688
689 FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
690 if (succeeded(maybeIZp) && op.verifyInputZeroPoint(*maybeIZp).failed())
691 return failure();
692
693 FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
694 if (succeeded(maybeWZp) && op.verifyWeightZeroPoint(*maybeWZp).failed())
695 return failure();
696
697 return success();
698}
699
700LogicalResult tosa::ConstOp::verify() {
701
702 auto attrType = llvm::dyn_cast<TensorType>(getValuesAttr().getType());
703 auto outputType = llvm::dyn_cast<TensorType>(getOutput().getType());
704
705 if (!attrType || !outputType) {
706 emitOpError("expected tensors for attr/result type");
707 return failure();
708 }
709
710 if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
711 outputType.getElementType())) {
712 if (result.getStorageType() == attrType.getElementType())
713 return success();
714 }
715
716 if (attrType.getElementType() != outputType.getElementType()) {
717 emitOpError("expected same attr/result element types");
718 return failure();
719 }
720
721 return success();
722}
723
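// Verifies the accumulator type against the input element type following the
// rules encoded below, e.g. an i8 input requires an i32 accumulator and an
// i16 input requires an i48 accumulator.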
724template <typename T>
725static LogicalResult verifyConvOpModes(T op) {
726 auto inputEType =
727 llvm::cast<ShapedType>(op.getInput().getType()).getElementType();
728
729 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
730 inputEType = quantType.getStorageType();
731
732 auto accType = op.getAccType();
733 if (inputEType.isInteger(8) && !accType.isInteger(32))
734 return op.emitOpError("accumulator type for i8 tensor is not i32");
735
736 if (inputEType.isInteger(16) && !accType.isInteger(48))
737 return op.emitOpError("accumulator type for i16 tensor is not i48");
738
739 if (isa<Float8E5M2Type, Float8E4M3Type>(inputEType) && !accType.isF16())
740 return op.emitOpError("accumulator type for f8 tensor is not f16");
741
742 if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
743 return op.emitOpError("accumulator type for f16 tensor is not f16/f32");
744
745 if (inputEType.isBF16() && !accType.isF32())
746 return op.emitOpError("accumulator type for bf16 tensor is not f32");
747
748 if (inputEType.isF32() && !accType.isF32())
749 return op.emitOpError("accumulator type for f32 tensor is not f32");
750
751 auto resultEType =
752 llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
753
754 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
755 resultEType = quantType.getStorageType();
756
757 return success();
758}
759
760//===----------------------------------------------------------------------===//
761// ERROR_IF functions.
762// ERROR_IF is a predicate that must set an error if the condition holds.
763//===----------------------------------------------------------------------===//
764
765template <typename T>
766static LogicalResult verifyConvOpErrorIf(T op) {
767 llvm::ArrayRef<int64_t> padding = op.getPad();
768 if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
769 return op.emitOpError("expect all padding values to be >= 0, got ")
770 << padding;
771
772 llvm::ArrayRef<int64_t> strides = op.getStride();
773 if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
774 return op.emitOpError("expect all stride values to be >= 1, got ")
775 << strides;
776
777 llvm::ArrayRef<int64_t> dilations = op.getDilation();
778 if (llvm::any_of(dilations, [](int64_t d) { return d < 1; }))
779 return op.emitOpError("expect all dilation values to be >= 1, got ")
780 << dilations;
781
782 const RankedTensorType outputType =
783 llvm::dyn_cast<RankedTensorType>(op.getOutput().getType());
784 if (!outputType)
785 // Skip following checks if output is not ranked
786 return success();
787
788 const RankedTensorType inputType =
789 llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
790 const RankedTensorType weightType =
791 llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
792
793 if (inputType && weightType) {
794 const auto verifyOutputSize =
795 [&op](const int64_t inputSize, const int64_t kernelSize,
796 const int64_t outputSize, const int64_t padBefore,
797 const int64_t padAfter, const int64_t stride,
798 const int64_t dilation, const llvm::StringRef dimName,
799 const llvm::StringRef dimAxis,
800 const llvm::StringRef padBeforeName,
801 const llvm::StringRef padAfterName) -> LogicalResult {
802 if (inputSize == ShapedType::kDynamic ||
803 kernelSize == ShapedType::kDynamic)
804 return success();
805
806 // ERROR_IF: O != idiv_check(I - 1 + pa + pb - (K - 1) * d, s) + 1
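      // For example (illustrative numbers): I = 16, pa = pb = 1, K = 3, d = 1,
      // s = 1 gives O = (16 - 1 + 1 + 1 - 2) / 1 + 1 = 16; a non-exact
      // division triggers the error below.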
807
808 const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
809 inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
810 stride);
811 if (!calculatedOutSizeMinusOne.has_value())
812 return op.emitOpError("expected input_")
813 << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
814 << padAfterName << " - (kernel_" << dimName
815 << " - 1) * dilation_" << dimAxis
816 << " to be wholly divisible by stride_" << dimAxis << ", got ("
817 << inputSize << " - 1 + " << padBefore << " + " << padAfter
818 << " - (" << kernelSize << " - 1) * " << dilation << ") / "
819 << stride;
820
821 const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
822 if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
823 return op.emitOpError("calculated output ")
824 << dimName << " did not match expected: "
825 << "calculated=" << calculatedOutSize
826 << ", expected=" << outputSize;
827
828 return success();
829 };
830
831 // input = [_,IH,IW,_], weight = [_,KH,KW,_], output = [_,OH,OW,_]
832 if constexpr (std::is_same<T, tosa::Conv2DOp>::value) {
833 if (failed(verifyOutputSize(
834 inputType.getDimSize(1), weightType.getDimSize(1),
835 outputType.getDimSize(1), padding[0], padding[1], strides[0],
836 dilations[0], "height", "y", "top", "bottom")))
837 return failure();
838
839 if (failed(verifyOutputSize(
840 inputType.getDimSize(2), weightType.getDimSize(2),
841 outputType.getDimSize(2), padding[2], padding[3], strides[1],
842 dilations[1], "width", "x", "left", "right")))
843 return failure();
844 }
845
846 // input = [_,IH,IW,_], weight = [KH,KW,_,_], output = [_,OH,OW,_]
847 if constexpr (std::is_same<T, tosa::DepthwiseConv2DOp>::value) {
848 if (failed(verifyOutputSize(
849 inputType.getDimSize(1), weightType.getDimSize(0),
850 outputType.getDimSize(1), padding[0], padding[1], strides[0],
851 dilations[0], "height", "y", "top", "bottom")))
852 return failure();
853
854 if (failed(verifyOutputSize(
855 inputType.getDimSize(2), weightType.getDimSize(1),
856 outputType.getDimSize(2), padding[2], padding[3], strides[1],
857 dilations[1], "width", "x", "left", "right")))
858 return failure();
859 }
860
861 // input = [_,ID,IH,IW,_], weight = [_,KD,KH,KW,_], output = [_,OD,OH,OW,_]
862 if constexpr (std::is_same<T, tosa::Conv3DOp>::value) {
863 if (failed(verifyOutputSize(
864 inputType.getDimSize(1), weightType.getDimSize(1),
865 outputType.getDimSize(1), padding[0], padding[1], strides[0],
866 dilations[0], "depth", "d", "front", "back")))
867 return failure();
868
869 if (failed(verifyOutputSize(
870 inputType.getDimSize(2), weightType.getDimSize(2),
871 outputType.getDimSize(2), padding[2], padding[3], strides[1],
872 dilations[1], "height", "y", "top", "bottom")))
873 return failure();
874
875 if (failed(verifyOutputSize(
876 inputType.getDimSize(3), weightType.getDimSize(3),
877 outputType.getDimSize(3), padding[4], padding[5], strides[2],
878 dilations[2], "width", "x", "left", "right")))
879 return failure();
880 }
881 }
882
883 const RankedTensorType biasType =
884 llvm::dyn_cast<RankedTensorType>(op.getBias().getType());
885 if (!biasType)
886 // Skip following checks if bias is not ranked
887 return success();
888
889 const int64_t biasChannels = biasType.getDimSize(0);
890 const int64_t outputChannels =
891 outputType.getDimSize(outputType.getRank() - 1);
892 if (biasChannels == ShapedType::kDynamic ||
893 outputChannels == ShapedType::kDynamic)
894 // Skip following checks if biasChannels or outputChannels is dynamic dim
895 return success();
896
897 if (biasChannels != outputChannels && biasChannels != 1)
898 return op.emitOpError(
899 "bias channels expected to be equal to output channels (")
900 << outputChannels << ") or 1, got " << biasChannels;
901
902 return success();
903}
904
905// Verify that the two given types have the same element type and compatible shape.
906static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, Type type1,
907 StringRef name1, Type type2,
908 StringRef name2) {
909 auto shapeType1 = dyn_cast<ShapedType>(type1);
910 auto shapeType2 = dyn_cast<ShapedType>(type2);
911 if (!shapeType1 || !shapeType2)
912 return failure();
913
914 auto elemType1 = shapeType1.getElementType();
915 auto elemType2 = shapeType2.getElementType();
916 if (elemType1 != elemType2)
917 return op->emitOpError()
918 << "require same element type for " << name1 << " (" << elemType1
919 << ") and " << name2 << " (" << elemType2 << ")";
920
921 if (failed(verifyCompatibleShape(type1, type2)))
922 return op->emitOpError()
923 << "require same shapes for " << name1 << " (" << type1 << ") and "
924 << name2 << " (" << type2 << ")";
925
926 return success();
927}
928
929// Verify that the two given tensor lists have the same length, element types, and shapes.
930static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, ValueRange list1,
931 StringRef name1,
932 ValueRange list2,
933 StringRef name2) {
934 if (list1.size() != list2.size())
935 return op->emitOpError()
936 << "require same number of values in " << name1 << " ("
937 << list1.size() << ") and " << name2 << " (" << list2.size() << ")";
938
939 for (auto [type1, type2] :
940 llvm::zip_equal(list1.getTypes(), list2.getTypes())) {
941 if (errorIfTypeOrShapeMismatch(op, type1, name1, type2, name2).failed())
942 return failure();
943 }
944
945 return success();
946}
947
948static inline LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type) {
949 ShapeAdaptor shapeAdaptor(type);
950 if (!shapeAdaptor.hasRank() || !shapeAdaptor.hasStaticShape())
951 return success();
952
953 return shapeAdaptor.getNumElements() == 1 ? success() : failure();
954}
955
956template <typename T>
957static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) {
958 Operation *symTableOp =
959 op->template getParentWithTrait<OpTrait::SymbolTable>();
960 if (!symTableOp)
961 // If the operation is not within the scope of a symbol table, we cannot
962 // verify it against its declaration.
963 return success();
964
965 SymbolTable symTable(symTableOp);
966 const auto varOp = symTable.lookup<tosa::VariableOp>(op.getName());
967
968 // Verify prior declaration
969 if (!varOp)
970 return op->emitOpError("'")
971 << op.getName() << "' has not been declared by 'tosa.variable'";
972
973 // Verify type and shape
974 auto variableType = getVariableType(varOp);
975 if (errorIfTypeOrShapeMismatch(op, type, name, variableType,
976 "the input tensor")
977 .failed())
978 return failure();
979 return success();
980}
981
982// Verify that aType and bType have the same element type.
983template <typename T>
984static LogicalResult verifySameElementTypes(T op, Type aType, Type bType,
985 StringRef aName = "input",
986 StringRef bName = "output") {
987 auto aTType = llvm::dyn_cast<TensorType>(aType);
988 auto bTType = llvm::dyn_cast<TensorType>(bType);
989 if (!aTType) {
990 op.emitOpError("expect shaped tensor for ") << aName << ", got " << aType;
991 return failure();
992 }
993 if (!bTType) {
994 op.emitOpError("expect shaped tensor for ") << bName << ", got " << bType;
995 return failure();
996 }
997 auto aElementType = aTType.getElementType();
998 auto bElementType = bTType.getElementType();
999 auto aQuantType =
1000 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(aElementType);
1001 auto bQuantType =
1002 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(bElementType);
1003 if ((aElementType.isIntOrIndexOrFloat() || aQuantType) &&
1004 (bElementType.isIntOrIndexOrFloat() || bQuantType) &&
1005 aElementType != bElementType) {
1006 // only check if both element types are int/index/float/UniformQuantized
1007 // eg, not sure how to check quant::QuantizedType
1008 // this happens in test_conv2d_q_grouped_convolution in
1009 // tfl-to-tosa-pipeline.mlir
1010 op.emitOpError("expect ")
1011 << aName << " and " << bName << " to have same element type, got "
1012 << aElementType << " and " << bElementType;
1013 return failure();
1014 }
1015 return success();
1016}
1017
1018LogicalResult tosa::ArgMaxOp::verify() {
1019 const ShapedType resultType = llvm::cast<ShapedType>(getType());
1020
1021 // Ensure the output element type is an integer type
1022 if (const auto resultETy = resultType.getElementType();
1023 !resultETy.isIntOrIndex())
1024 return emitOpError("result tensor is not of integer type");
1025
1026 const auto inputType = llvm::cast<ShapedType>(getInput().getType());
1027 if (!inputType.hasRank())
1028 return success();
1029
1030 // Ensure axis is within the tensor rank
1031 const int64_t axis = getAxisAttr().getInt();
1032 if (((axis < 0) || axis >= inputType.getRank()))
1033 return emitOpError("specified axis is outside the rank of the tensor");
1034
1035 if (!resultType.hasRank())
1036 return success();
1037
1038 const ArrayRef<int64_t> inputShape = inputType.getShape();
1039 const ArrayRef<int64_t> outputShape = resultType.getShape();
1040 llvm::SmallVector<int64_t> expectedOutputShape(inputShape);
1041 expectedOutputShape.erase(expectedOutputShape.begin() + axis);
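  // e.g. an input of shape [2, 3, 4] with axis = 1 gives an expected output
  // shape of [2, 4].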
1042 if (failed(verifyCompatibleShape(expectedOutputShape, outputShape)))
1043 return emitOpError("expected output shape '")
1044 << expectedOutputShape << "', got '" << outputShape << "'";
1045
1046 return success();
1047}
1048
1049template <typename T>
1050static LogicalResult verifyPoolingOp(T op) {
1051 const llvm::ArrayRef<int64_t> kernel = op.getKernel();
1052 if (llvm::any_of(kernel, [](int64_t s) { return s < 1; }))
1053 return op.emitOpError("expect all kernel values to be >= 1, got ")
1054 << kernel;
1055
1056 const llvm::ArrayRef<int64_t> strides = op.getStride();
1057 if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
1058 return op.emitOpError("expect all stride values to be >= 1, got ")
1059 << strides;
1060
1061 const llvm::ArrayRef<int64_t> padding = op.getPad();
1062 if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
1063 return op.emitOpError("expect all padding values to be >= 0, got ")
1064 << padding;
1065
1066 // Padding must be less than kernel size to avoid a divide-by-zero
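  // e.g. (illustrative) kernel = [2, 2] with pad = [0, 0, 2, 0] is rejected
  // because pad_left equals kernel_x.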
1067 const int64_t kernelX = kernel[1];
1068 const int64_t padLeft = padding[2];
1069 const int64_t padRight = padding[3];
1070 if (padRight >= kernelX || padLeft >= kernelX)
1071 return op.emitOpError("expected left/right padding to be less than the "
1072 "width of the kernel, got pad_left=")
1073 << padLeft << ", pad_right=" << padRight << ", kernel_x=" << kernelX;
1074
1075 const int64_t kernelY = kernel[0];
1076 const int64_t padTop = padding[0];
1077 const int64_t padBottom = padding[1];
1078 if (padTop >= kernelY || padBottom >= kernelY)
1079 return op.emitOpError("expected top/bottom padding to be less than the "
1080 "height of the kernel, got pad_top=")
1081 << padTop << ", pad_bottom=" << padBottom
1082 << ", kernel_y=" << kernelY;
1083
1084 const auto inputType =
1085 llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
1086 const auto outputType =
1087 llvm::dyn_cast<RankedTensorType>(op.getResult().getType());
1088 if (!inputType || !outputType)
1089 return success();
1090
1091 const auto verifyOutputSize =
1092 [&op](const int64_t inputSize, const int64_t outputSize,
1093 const int64_t kernelSize, const int64_t strideSize,
1094 const int64_t padBefore, const int64_t padAfter,
1095 const llvm::StringRef dimName, const llvm::StringRef dimAxis,
1096 const llvm::StringRef padBeforeName,
1097 const llvm::StringRef padAfterName) -> LogicalResult {
1098 if (ShapedType::isDynamic(inputSize))
1099 return success();
1100
1101 const std::optional<int64_t> calculatedOutSizeMinusOne =
1102 idivCheck(inputSize + padBefore + padAfter - kernelSize, strideSize);
1103 if (!calculatedOutSizeMinusOne.has_value())
1104 return op.emitOpError("expected input_")
1105 << dimName << " + pad_" << padBeforeName << " + pad_"
1106 << padAfterName << " - kernel_" << dimAxis
1107 << " to be wholly divisible by stride_" << dimAxis << ", got ("
1108 << inputSize << " + " << padBefore << " + " << padAfter << " - "
1109 << kernelSize << ") / " << strideSize;
1110
1111 const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
1112 if (ShapedType::isStatic(outputSize) && calculatedOutSize != outputSize)
1113 return op.emitOpError("calculated output ")
1114 << dimName << " did not match expected: "
1115 << "calculated=" << calculatedOutSize
1116 << ", expected=" << outputSize;
1117
1118 return success();
1119 };
1120
1121 if (failed(verifyOutputSize(inputType.getDimSize(1), outputType.getDimSize(1),
1122 kernel[0], strides[0], padding[0], padding[1],
1123 "height", "y", "top", "bottom")))
1124 return failure();
1125
1126 if (failed(verifyOutputSize(inputType.getDimSize(2), outputType.getDimSize(2),
1127 kernel[1], strides[1], padding[2], padding[3],
1128 "width", "x", "left", "right")))
1129 return failure();
1130
1131 return success();
1132}
1133
1134LogicalResult tosa::AvgPool2dOp::verify() {
1135 if (failed(verifyPoolingOp(*this)))
1136 return failure();
1137
1138 const Type inputETy = getStorageElementTypeOrSelf(getInput().getType());
1139 const Type resultETy = getStorageElementTypeOrSelf(getOutput().getType());
1140 const Type inputZpETy = getStorageElementTypeOrSelf(getInputZp().getType());
1141 const Type outputZpETy = getStorageElementTypeOrSelf(getOutputZp().getType());
1142
1143 auto accType = getAccType();
1144 if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
1145 return emitOpError("accumulator type for integer tensor is not i32");
1146
1147 if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))
1148 return emitOpError("accumulator type for f16 tensor is not f16/f32");
1149
1150 if (inputETy.isBF16() && !accType.isF32())
1151 return emitOpError("accumulator type for bf16 tensor is not f32");
1152
1153 if (inputETy.isF32() && !accType.isF32())
1154 return emitOpError("accumulator type for f32 tensor is not f32");
1155
1156 if (inputETy != inputZpETy)
1157 return emitOpError("expect both input and its zero point are the same "
1158 "element type, got ")
1159 << inputETy << " and " << inputZpETy;
1160
1161 if (resultETy != outputZpETy)
1162 return emitOpError("expect both output and its zero point are the same "
1163 "element type, got ")
1164 << resultETy << " and " << outputZpETy;
1165
1166 FailureOr<int64_t> maybeIZp = getInputZeroPoint();
1167 if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
1168 return failure();
1169
1170 FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
1171 if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
1172 return failure();
1173
1174 return success();
1175}
1176
1177LogicalResult tosa::ClampOp::verify() {
1178 mlir::Type inputETy =
1179 llvm::cast<ShapedType>(getInput().getType()).getElementType();
1180 if (auto quantType =
1181 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
1182 inputETy = quantType.getStorageType();
1183 }
1184 mlir::Type outputETy =
1185 llvm::cast<ShapedType>(getOutput().getType()).getElementType();
1186 if (auto quantType =
1187 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
1188 outputETy = quantType.getStorageType();
1189 }
1190 if (inputETy != outputETy)
1191 return emitOpError("input/output element types are incompatible.");
1192
1193 auto maxValAttr = getMaxValAttr();
1194 auto minValAttr = getMinValAttr();
1195
1196 unsigned dataTypeBitWidth = inputETy.getIntOrFloatBitWidth();
1197
1198 if (inputETy.isInteger(dataTypeBitWidth)) {
1199 // if input datatype is integer, check that the min_val/max_val attributes
1200 // are integer attributes, and that their type is the same as the input's
1201 // datatype
1202 auto intMaxValAttr = mlir::dyn_cast<mlir::IntegerAttr>(maxValAttr);
1203 auto intMinValAttr = mlir::dyn_cast<mlir::IntegerAttr>(minValAttr);
1204 if (!intMaxValAttr || !intMinValAttr ||
1205 (intMaxValAttr.getType() != intMinValAttr.getType()) ||
1206 (intMaxValAttr.getType() != inputETy))
1207 return emitOpError("min/max attributes types are incompatible with "
1208 "input/output element types.");
1209
1210 const bool isUnsigned = inputETy.isUnsignedInteger();
1211 const bool isBoolean = inputETy.isInteger(1);
1212 const APInt minVal = intMinValAttr.getValue();
1213 const APInt maxVal = intMaxValAttr.getValue();
1214 if ((isUnsigned || isBoolean) ? maxVal.ult(minVal) : maxVal.slt(minVal))
1215 return emitOpError("expected min_val <= max_val, got min_val=")
1216 << minValAttr << ", max_val=" << maxValAttr;
1217 } else {
1218 // otherwise, input datatype is float, check that the min_val/max_val
1219 // attributes share the same type and that their type is the same as the
1220 // input's datatype
1221 auto floatMaxValAttr = mlir::dyn_cast<mlir::FloatAttr>(maxValAttr);
1222 auto floatMinValAttr = mlir::dyn_cast<mlir::FloatAttr>(minValAttr);
1223 if (!floatMaxValAttr || !floatMinValAttr ||
1224 (floatMaxValAttr.getType() != floatMinValAttr.getType()) ||
1225 (floatMaxValAttr.getType() != inputETy))
1226 return emitOpError("min/max attributes types are incompatible with "
1227 "input/output element types.");
1228
1229 const APFloat minVal = floatMinValAttr.getValue();
1230 const APFloat maxVal = floatMaxValAttr.getValue();
1231 if (minVal.isNaN() || maxVal.isNaN())
1232 return emitOpError("min/max attributes should not be 'NaN', got min_val=")
1233 << minValAttr << ", max_val=" << maxValAttr;
1234
1235 if (maxVal < minVal)
1236 return emitOpError("expected min_val <= max_val, got min_val=")
1237 << minValAttr << ", max_val=" << maxValAttr;
1238 }
1239
1240 return success();
1241}
1242
1243//===----------------------------------------------------------------------===//
1244// TOSA Operator Quantization Builders.
1245//===----------------------------------------------------------------------===//
1246
1247/// This builder is called on all convolution operators except TransposeConv,
1248/// which has specialized output shape semantics. The builder also defines the
1249/// bit width of the output given the bit widths of the input and weight.
1250static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
1251 Type outputType, Value input, Value weight,
1252 Value bias, DenseI64ArrayAttr pad,
1253 DenseI64ArrayAttr stride,
1254 DenseI64ArrayAttr dilation,
1255 TypeAttr accType) {
1256 auto zps = createZPsAsConst(builder, input, weight);
1257 result.addOperands({input, weight, bias, zps.first, zps.second});
1258 result.addAttribute("pad", pad);
1259 result.addAttribute("stride", stride);
1260 result.addAttribute("dilation", dilation);
1261 result.addAttribute("acc_type", accType);
1262 Type finalOutputType = outputType;
1263 auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
1264 if (quantAttr) {
1265 finalOutputType =
1266 buildConvOpResultTypeInfo(builder, outputType, input, weight);
1267 }
1268 result.addTypes(finalOutputType);
1269}
1270
1271/// Handles tosa.transpose_conv2d which has outpad and output shape
1272/// attributes.
1273static void
1274buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
1275 Type outputType, Value input, Value weight,
1276 Value bias, DenseI64ArrayAttr outpad,
1277 DenseI64ArrayAttr stride, TypeAttr accType) {
1278 auto zps = createZPsAsConst(builder, input, weight);
1279 result.addOperands({input, weight, bias, zps.first, zps.second});
1280 result.addAttribute("out_pad", outpad);
1281 result.addAttribute("stride", stride);
1282 result.addAttribute("acc_type", accType);
1283 Type finalOutputType = outputType;
1284 auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
1285 if (quantAttr) {
1286 finalOutputType =
1287 buildConvOpResultTypeInfo(builder, outputType, input, weight);
1288 }
1289 result.addTypes(finalOutputType);
1290}
1291
1292/// The tosa.matmul op is also intended to be generated where a fully_connected
1293/// op would otherwise be constructed but the weight is not a constant. In that
1294/// case, the fully_connected op must be expressed using matmul.
1295/// TODO: Add link to the legalization document explaining this.
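/// When the operands are quantized, the result element type is widened to the
/// accumulator width chosen below: i48 for 16-bit storage types, i32 otherwise.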
1296static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
1297 OperationState &result, Type outputType,
1298 Value a, Value b) {
1299 auto zps = createZPsAsConst(builder, a, b);
1300 result.addOperands({a, b, zps.first, zps.second});
1301
1302 Type finalOutputType{outputType};
1303 if (auto quantAttr = buildMatMulOpQuantizationAttr(builder, a, b)) {
1304 auto eType = getStorageElementTypeOrSelf(a.getType());
1305 auto inputBits = eType.getIntOrFloatBitWidth();
1306
1307 auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
1308 assert(outputShapedType && "Output must be a shaped type");
1309
1310 IntegerType accElementType;
1311 if (inputBits == 16)
1312 accElementType = builder.getIntegerType(48);
1313 else
1314 accElementType = builder.getI32Type();
1315
1316 finalOutputType = outputShapedType.clone(accElementType);
1317 }
1318 result.addTypes(finalOutputType);
1319}
1320
1321/// Both the tosa.avg_pool2d and unary ops use the same
1322/// UnaryOpQuantizationAttr, but the avg_pool2d operator has its own builder
1323/// because it has additional parameters that are not part of the unary ops.
1324static void
1325buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
1326 Type outputType, Value input,
1327 DenseArrayAttr kernel, DenseArrayAttr stride,
1328 DenseArrayAttr pad, TypeAttr accType) {
1329 const Location loc{result.location};
1330 int64_t inputZp{0};
1331 int64_t outputZp{0};
1332
1333 if (auto quantAttr =
1334 buildUnaryOpQuantizationAttr(builder, input, outputType)) {
1335 inputZp = quantAttr.getInputZp();
1336 outputZp = quantAttr.getOutputZp();
1337 }
1338 const std::optional<Value> inputZpOp =
1339 createZeroPointTensor(builder, loc, input.getType(), inputZp);
1340 if (!inputZpOp) {
1341 (void)emitError(
1342 loc,
1343 "Failed to create input zero point tensor for quantized AVG_POOL2D op");
1344 }
1345 const std::optional<Value> outputZpOp =
1346 createZeroPointTensor(builder, loc, outputType, outputZp);
1347 if (!outputZpOp) {
1348 (void)emitError(loc, "Failed to create output zero point tensor for "
1349 "quantized AVG_POOL2D op");
1350 }
1351
1352 if (inputZpOp && outputZpOp) {
1353 result.addOperands({input, inputZpOp.value(), outputZpOp.value()});
1354 } else {
1355 // Failed to create one or more zero points above: just add the input as
1356 // an operand. This will trigger an error when building the op because of
1357 // the missing zero points.
1358 result.addOperands({input});
1359 }
1360 result.addAttribute("kernel", kernel);
1361 result.addAttribute("stride", stride);
1362 result.addAttribute("pad", pad);
1363 result.addAttribute("acc_type", accType);
1364 result.types.push_back(outputType);
1365}
1366
1367/// This builder is called on the single-parameter negate operator
1368/// to construct input and output zero points based on their
1369/// types.
1370static void buildNegateOpWithQuantInfo(OpBuilder &builder,
1371 OperationState &result, Type outputType,
1372 Value input) {
1373 const Location loc{result.location};
1374 int64_t input1Zp{0};
1375 int64_t outputZp{0};
1376 auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
1377 if (quantAttr) {
1378 input1Zp = quantAttr.getInputZp();
1379 outputZp = quantAttr.getOutputZp();
1380 }
1381 const std::optional<Value> input1ZpOp =
1382 createZeroPointTensor(builder, loc, input.getType(), input1Zp);
1383 if (!input1ZpOp) {
1384 (void)emitError(
1385 loc, "Failed to create input1 zero point for quantized NEGATE op");
1386 }
1387
1388 const std::optional<Value> outputZpOp =
1389 createZeroPointTensor(builder, loc, input.getType(), outputZp);
1390 if (!outputZpOp) {
1391 (void)emitError(
1392 loc, "Failed to create output zero point for quantized NEGATE op");
1393 }
1394
1395 if (input1ZpOp && outputZpOp) {
1396 result.addOperands({input, input1ZpOp.value(), outputZpOp.value()});
1397 } else {
1398 // Failed to create one or more zero points above: just add the input as
1399 // an operand. This will trigger an error when building the op because of
1400 // the missing zero points.
1401 result.addOperands({input});
1402 }
1403
1404 result.types.push_back(outputType);
1405}
1406
1407/// This builder is called on TOSA pad operator that needs to create its own
1408/// OptionalAttr quantization_attr parameter to scale the padding values
1409/// correctly. When no pad_const is provided, zero-padding is assumed.
1410static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
1411 Type outputType, Value input,
1412 Value paddings) {
1413 const Location loc{result.location};
1414 int32_t zp{0};
1415 const auto quantAttr = buildPadOpQuantizationAttr(builder, input);
1416 if (quantAttr) {
1417 zp = static_cast<int32_t>(quantAttr.getInputZp());
1418 }
1419 const auto padConstOp{createPadConstTensor(builder, loc, input, zp)};
1420 result.addOperands({input, paddings, padConstOp});
1421 result.types.push_back(outputType);
1422}
1423
1424static void buildVariableOp(OpBuilder &builder, OperationState &result,
1425 StringRef name, Type variableType,
1426 Attribute initialValue) {
1427 const Location loc{result.location};
1428 auto nameAttr = builder.getStringAttr(name);
1429
1430 auto shapedType = dyn_cast<ShapedType>(variableType);
1431 if (!shapedType) {
1432 (void)emitError(loc, "variable type must be a shaped type");
1433 return;
1434 }
1435 if (!shapedType.hasRank()) {
1436 (void)emitError(loc, "variable type must be a ranked type");
1437 return;
1438 }
1439
1440 auto elementType = shapedType.getElementType();
1441 auto elementTypeAttr = TypeAttr::get(elementType);
1442 ArrayRef<int64_t> shape = shapedType.getShape();
1443 auto varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
1444
1445 result.addAttribute("sym_name", nameAttr);
1446 result.addAttribute("var_shape", varShapeAttr);
1447 result.addAttribute("type", elementTypeAttr);
1448 result.addAttribute("initial_value", initialValue);
1449}
1450
1451//===----------------------------------------------------------------------===//
1452// TOSA Operator Return Type Inference.
1453//===----------------------------------------------------------------------===//
1454
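// Resolves the broadcast result shape of all operand shapes: size-1 dimensions
// stretch to the other operand's size and equal sizes must match. For example
// (illustrative), shapes [1, 3] and [2, 1] resolve to [2, 3], while [2, 3] and
// [4, 3] fail to broadcast.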
1455static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
1456 SmallVector<int64_t> &outShape) {
1457 int64_t outRank = 0;
1458 for (int i = 0, e = operands.size(); i != e; ++i) {
1459 auto shape = operands.getShape(i);
1460 if (!shape.hasRank()) {
1461 // TODO(jennik): Update function to have better case handling for
1462 // invalid operands and for ranked tensors.
1463 return failure();
1464 }
1465 outRank = std::max<int64_t>(outRank, shape.getRank());
1466 }
1467
1468 outShape.resize(outRank, 1);
1469
1470 for (int i = 0, e = operands.size(); i != e; ++i) {
1471 auto shape = operands.getShape(i);
1472 auto rankDiff = outShape.size() - shape.getRank();
1473
1474 for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
1475 auto dim1 = outShape[i + rankDiff];
1476 auto dim2 = shape.getDimSize(i);
1477 auto resolvedDim = dim1;
1478
1479 if (dim1 == 1) {
1480 resolvedDim = dim2;
1481 } else if (dim2 == 1) {
1482 resolvedDim = dim1;
1483 } else if (dim1 != dim2) {
1484 return failure();
1485 }
1486 outShape[i + rankDiff] = resolvedDim;
1487 }
1488 }
1489
1490 return success();
1491}
1492
1493LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
1494 MLIRContext *context, ::std::optional<Location> location,
1495 ArgMaxOp::Adaptor adaptor,
1496 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1497 ShapeAdaptor inputShape(adaptor.getInput().getType());
1498 IntegerAttr axis = adaptor.getProperties().axis;
1499 int32_t axisVal = axis.getValue().getSExtValue();
1500
1501 if (!inputShape.hasRank()) {
1502 inferredReturnShapes.push_back(ShapedTypeComponents());
1503 return success();
1504 }
1505
1506 SmallVector<int64_t> outShape;
1507 outShape.reserve(inputShape.getRank() - 1);
1508 for (int i = 0, s = inputShape.getRank(); i < s; i++) {
1509 if (i == axisVal)
1510 continue;
1511 outShape.push_back(inputShape.getDimSize(i));
1512 }
1513
1514 inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1515 return success();
1516}
1517
1518LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
1519 MLIRContext *context, ::std::optional<Location> location,
1520 RFFT2dOp::Adaptor adaptor,
1521 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1522 ShapeAdaptor inputShape(adaptor.getInputReal().getType());
1523
1524 if (!inputShape.hasRank())
1525 return failure();
1526
1527 llvm::SmallVector<int64_t> outputShape;
1528 outputShape.resize(3, ShapedType::kDynamic);
1529 outputShape[0] = inputShape.getDimSize(0);
1530 outputShape[1] = inputShape.getDimSize(1);
1531 int64_t inWidth = inputShape.getDimSize(2);
1532
1533 // Note that we can support this calculation symbolically
1534 // in the future e.g. [x, y, z] -> [x, y, z / 2 + 1]
1535 if (inWidth != ShapedType::kDynamic)
1536 outputShape[2] = inWidth / 2 + 1;
1537
1538 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1539 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1540
1541 return success();
1542}
1543
1544static LogicalResult verifyDimIsPowerOfTwo(Operation *op, const int64_t dimSize,
1545 const llvm::StringRef dimName) {
1546 const bool isPowerOfTwo = (dimSize & (dimSize - 1)) == 0 && dimSize > 0;
1547 if (!isPowerOfTwo)
1548 return op->emitOpError("expected ")
1549 << dimName << " to be a power of two, got " << dimSize;
1550
1551 return success();
1552}
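// The bit trick above relies on a power of two having a single set bit:
// 8 & 7 == 0b1000 & 0b0111 == 0, while 6 & 5 == 0b110 & 0b101 != 0, so 6 is
// rejected; the dimSize > 0 guard keeps 0 from slipping through.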
1553
1554LogicalResult tosa::RFFT2dOp::verify() {
1555 const auto outputTypes = getResultTypes();
1556 if (failed(verifyCompatibleShapes(outputTypes)))
1557 return emitOpError("expected output shapes to match, got ") << outputTypes;
1558
1559 const auto inputType =
1560 llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
1561 if (!inputType)
1562 return success();
1563
1564 const int64_t height = inputType.getDimSize(1);
1565 if (ShapedType::isStatic(height) &&
1566 failed(verifyDimIsPowerOfTwo(*this, height, "height")))
1567 return failure();
1568
1569 const int64_t width = inputType.getDimSize(2);
1570 if (ShapedType::isStatic(width) &&
1571 failed(verifyDimIsPowerOfTwo(*this, width, "width")))
1572 return failure();
1573
1574 const auto outputType = llvm::dyn_cast<RankedTensorType>(outputTypes[0]);
1575 if (!outputType)
1576 return success();
1577
1578 // Batch and height input/output dimensions should match
1579 if (failed(verifyCompatibleShape(inputType.getShape().drop_back(),
1580 outputType.getShape().drop_back())))
1581 return emitOpError("expected batch and height dimensions of input/output "
1582 "to match, got input=")
1583 << inputType << " output=" << outputType;
1584
1585 // Output width dimension expected to be input_width / 2 + 1
1586 const int64_t outputWidth = outputType.getDimSize(2);
1587 if (ShapedType::isStatic(width) && ShapedType::isStatic(outputWidth) &&
1588 (outputWidth != (width / 2) + 1))
1589 return emitOpError(
1590 "expected output width to be equal to input_width / 2 + 1, got ")
1591 << outputWidth;
1592
1593 return success();
1594}
1595
1596LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
1597 MLIRContext *context, ::std::optional<Location> location,
1598 FFT2dOp::Adaptor adaptor,
1599 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1600 inferredReturnShapes.push_back(
1601 ShapedTypeComponents(ShapeAdaptor(adaptor.getInputReal().getType())));
1602 inferredReturnShapes.push_back(
1603 ShapedTypeComponents(ShapeAdaptor(adaptor.getInputImag().getType())));
1604 return success();
1605}
1606
1607LogicalResult tosa::FFT2dOp::verify() {
1608 const auto inputRealType =
1609 llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
1610 const auto inputImagType =
1611 llvm::dyn_cast<RankedTensorType>(getInputImag().getType());
1612 if (!inputRealType || !inputImagType)
1613 return success();
1614
1615 const auto trySelectStaticDim = [](const int64_t a, const int64_t b) {
1616 return ShapedType::isDynamic(a) ? a : b;
1617 };
1618
1619 const int64_t height = trySelectStaticDim(inputRealType.getDimSize(1),
1620 inputImagType.getDimSize(1));
1621 if (ShapedType::isStatic(height) &&
1622 failed(verifyDimIsPowerOfTwo(*this, height, "height")))
1623 return failure();
1624
1625 const int64_t width = trySelectStaticDim(inputRealType.getDimSize(2),
1626 inputImagType.getDimSize(2));
1627 if (ShapedType::isStatic(width) &&
1628 failed(verifyDimIsPowerOfTwo(*this, width, "width")))
1629 return failure();
1630
1631 return success();
1632}
1633
1634LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
1635 MLIRContext *context, ::std::optional<Location> location,
1636 ConcatOp::Adaptor adaptor,
1637 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1638 // Infer all dimension sizes by reducing based on inputs.
1639 const Properties &prop = adaptor.getProperties();
1640 int32_t axis = prop.axis.getValue().getSExtValue();
1641 llvm::SmallVector<int64_t> outputShape;
1642 bool hasRankedInput = false;
1643 for (auto operand : adaptor.getOperands()) {
1644 ShapeAdaptor operandShape(operand.getType());
1645 if (!operandShape.hasRank())
1646 continue;
1647
1648 // Copy the Operand's rank.
1649 if (!hasRankedInput)
1650 outputShape.resize(operandShape.getRank(), ShapedType::kDynamic);
1651
1652 // Copy static sizes on non-axis dims and check that they agree across operands.
1653 for (int i = 0, s = operandShape.getRank(); i < s; i++) {
1654 if (i == axis || operandShape.isDynamicDim(i))
1655 continue;
1656 if (outputShape[i] == ShapedType::kDynamic)
1657 outputShape[i] = operandShape.getDimSize(i);
1658 if (outputShape[i] != operandShape.getDimSize(i))
1659 return emitOptionalError(location,
1660 "Cannot concat tensors with different sizes"
1661 " on the non-axis dimension ",
1662 i);
1663 }
1664
1665 hasRankedInput = true;
1666 }
1667
1668 if (adaptor.getInput1().empty())
1669 return failure();
1670
1671 Type inputType =
1672 llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType();
1673 if (!hasRankedInput) {
1674 inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
1675 return success();
1676 }
1677
1678 // Determine the dimension size along the concatenation axis.
1679 int64_t concatDimSize = 0;
1680 for (auto operand : adaptor.getOperands()) {
1681 ShapeAdaptor operandShape(operand.getType());
1682
1683 // We need to know the length of the concatenation axis of all inputs to
1684 // determine the dimension size of the output shape.
1685 if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
1686 concatDimSize = ShapedType::kDynamic;
1687 break;
1688 }
1689
1690 concatDimSize += operandShape.getDimSize(axis);
1691 }
1692
1693 outputShape[axis] = concatDimSize;
1694
1695 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
1696 return success();
1697}
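// Shape example (a sketch): concatenating [2, 3] and [2, 5] on axis = 1 infers
// [2, 8]; non-axis dims must agree, and if any operand's axis dim (or rank) is
// unknown the axis dimension of the result stays dynamic.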
1698
1699LogicalResult tosa::ConcatOp::verify() {
1700 // check that each input has same element type as output
1701 auto outType = getOutput().getType();
1702 const Operation::operand_range inputList = getInput1();
1703
1704 // Check there is at least one input
1705 if (inputList.empty())
1706 return emitOpError("expect at least one input");
1707
1708 if (!llvm::all_of(inputList, [&](auto input) {
1709 return succeeded(verifySameElementTypes(
1710 *this, /* inType = */ input.getType(), outType));
1711 })) {
1712 return failure();
1713 }
1714
1715 const int32_t axis = getAxis();
1716 ShapeAdaptor firstRankedInputShape = nullptr;
1717 for (const auto &input : inputList) {
1718 const Type inputType = input.getType();
1719 ShapeAdaptor currShape(inputType);
1720 if (currShape.hasRank()) {
1721 firstRankedInputShape = currShape;
1722 // Check axis is in expected range
1723 if (axis < 0 || axis >= firstRankedInputShape.getRank())
1724 return emitOpError("expect axis to be within range 0 < axis < "
1725 "rank(input1[firstRankedTensorIdx]), got ")
1726 << axis;
1727 break;
1728 }
1729 }
1730
1731 const auto allOperandsHasRank = [](const Value input) {
1732 return ShapeAdaptor(input.getType()).hasRank();
1733 };
1734 if (llvm::all_of(inputList, allOperandsHasRank)) {
1735 const int64_t firstInputRank = firstRankedInputShape.getRank();
1736
1737 for (const auto &[index, input] : llvm::enumerate(inputList.drop_front())) {
1738 const ShapeAdaptor inputShape(input.getType());
1739 const int64_t inputRank = inputShape.getRank();
1740 const size_t operandNum = index + 1;
1741
1742 // Check that each operand has the same rank
1743 if (inputRank != firstInputRank)
1744 return emitOpError(
1745 "expect all operands to have the same rank, but got ")
1746 << firstInputRank << " vs " << inputRank << " on operands 0 and "
1747 << operandNum;
1748
1749 // Check non-axis dims match
1750 for (int i = 0; i < inputRank; i++) {
1751 const int64_t inputDim = inputShape.getDimSize(i);
1752 const int64_t firstInputDim = firstRankedInputShape.getDimSize(i);
1753 if (i == axis || firstRankedInputShape.isDynamicDim(i) ||
1754 inputShape.isDynamicDim(i))
1755 continue;
1756 if (inputDim != firstInputDim)
1757 return emitOpError("expect all operand shapes to have the same sizes "
1758 "on non-axis dimensions, but got ")
1759 << inputDim << " vs " << firstInputDim << " at index " << i
1760 << " on operands 0 and " << operandNum;
1761 }
1762 }
1763
1764 const ShapeAdaptor outputShape(outType);
1765 if (outputShape.hasRank() && outputShape.getRank() != firstInputRank)
1766 return emitOpError("expect output rank to match inputs rank, got ")
1767 << outputShape.getRank() << " vs " << firstInputRank;
1768
1769 // ERROR_IF(axis_sum != shape[axis]);
1770 int64_t axisSum = 0;
1771 for (const auto &input : inputList) {
1772 const ShapeAdaptor inputShape(input.getType());
1773 if (inputShape.isDynamicDim(axis)) {
1774 // make axisSum negative to indicate invalid value
1775 axisSum = -1;
1776 break;
1777 }
1778 axisSum += inputShape.getDimSize(axis);
1779 }
1780
1781 if (axisSum >= 0 && outputShape.hasRank() &&
1782 !outputShape.isDynamicDim(axis) &&
1783 axisSum != outputShape.getDimSize(axis))
1784 return emitOpError("requires sum of axis dimensions of input1 "
1785 "equal to output axis dimension, got ")
1786 << axisSum << " and " << outputShape.getDimSize(axis);
1787 }
1788
1789 return success();
1790}
1791
1792LogicalResult tosa::EqualOp::inferReturnTypeComponents(
1793 MLIRContext *context, ::std::optional<Location> location,
1794 ValueShapeRange operands, DictionaryAttr attributes,
1795 OpaqueProperties properties, RegionRange regions,
1796 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1797 auto elementType = IntegerType::get(context, /*width=*/1);
1798
1799 llvm::SmallVector<int64_t> outShape;
1800 if (resolveBroadcastShape(operands, outShape).failed()) {
1801 inferredReturnShapes.push_back(ShapedTypeComponents(elementType));
1802 return success();
1803 }
1804
1805 inferredReturnShapes.push_back(ShapedTypeComponents(outShape, elementType));
1806 return success();
1807}
1808
1809bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1810 if (l.size() != r.size() || l.size() != 1)
1811 return false;
1812 return succeeded(verifyCompatibleShape(l[0], r[0]));
1813}
1814
1815LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
1816 MLIRContext *context, ::std::optional<Location> location,
1817 MatMulOp::Adaptor adaptor,
1818 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1819 ShapeAdaptor lhsShape(adaptor.getA().getType());
1820 ShapeAdaptor rhsShape(adaptor.getB().getType());
1821
1822 // All shapes are dynamic.
1823 SmallVector<int64_t> outShape;
1824 outShape.resize(3, ShapedType::kDynamic);
1825
1826 if (lhsShape.hasRank()) {
1827 outShape[0] = lhsShape.getDimSize(0);
1828 outShape[1] = lhsShape.getDimSize(1);
1829 }
1830
1831 if (rhsShape.hasRank()) {
1832 outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(0)
1833 : outShape[0];
1834 outShape[2] = rhsShape.getDimSize(2);
1835 }
1836
1837 inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1838 return success();
1839}
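// Shape example (a sketch): with a: [2, 3, 4] and b: [2, 4, 5] the inferred
// result is [2, 3, 5] ([N, H, C] x [N, C, W] -> [N, H, W]); a dynamic batch on
// one operand is filled in from the other when it is static.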
1840
1841LogicalResult MatMulOp::verify() {
1842 auto aType = llvm::dyn_cast<ShapedType>(getA().getType());
1843 auto bType = llvm::dyn_cast<ShapedType>(getB().getType());
1844
1845 // Must be shaped tensor types
1846 if (!aType)
1847 return emitOpError("expect a shaped tensor for input a, got ")
1848 << getA().getType();
1849
1850 if (!bType)
1851 return emitOpError("expect a shaped tensor for input b, got ")
1852 << getB().getType();
1853
1854 auto aElementType = aType.getElementType();
1855 auto bElementType = bType.getElementType();
1856
1857 auto aQuantizedEType =
1858 llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
1859 auto bQuantizedEType =
1860 llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);
1861
1862 if (aQuantizedEType || bQuantizedEType) {
1863 if (!aQuantizedEType || !bQuantizedEType) {
1864 return emitOpError("expect operands to be both quantized or both not "
1865 "quantized, got ")
1866 << aElementType << " and " << bElementType;
1867 }
1868 // both a and b have quantized element types
1869 auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth();
1870 auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth();
1871 if (aQuantWidth != bQuantWidth) {
1872 return emitOpError("expect quantized operands to have same widths, got ")
1873 << aQuantWidth << " and " << bQuantWidth;
1874 }
1875 }
1876
1877 // check a_zp and b_zp
1878 auto aEType = getStorageElementTypeOrSelf(aType);
1879 auto aZpEType = getStorageElementTypeOrSelf(getAZp().getType());
1880 if (aEType != aZpEType) {
1881 return emitOpError("expect input a and a_zp have the same "
1882 "element type, got ")
1883 << aEType << " and " << aZpEType;
1884 }
1885
1886 auto bEType = getStorageElementTypeOrSelf(bType);
1887 auto bZpEType = getStorageElementTypeOrSelf(getBZp().getType());
1888 if (bEType != bZpEType) {
1889 return emitOpError("expect input b and b_zp have the same "
1890 "element type, got ")
1891 << bEType << " and " << bZpEType;
1892 }
1893
1894 FailureOr<int64_t> maybeAZp = getAZeroPoint();
1895 if (succeeded(maybeAZp) && verifyAZeroPoint(*maybeAZp).failed())
1896 return failure();
1897
1898 FailureOr<int64_t> maybeBZp = getBZeroPoint();
1899 if (succeeded(maybeBZp) && verifyBZeroPoint(*maybeBZp).failed())
1900 return failure();
1901
1902 return success();
1903}
1904
1905LogicalResult tosa::MatmulTBlockScaledOp::inferReturnTypeComponents(
1906 MLIRContext *context, ::std::optional<Location> location,
1907 MatmulTBlockScaledOp::Adaptor adaptor,
1908 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1909 SmallVector<int64_t, 3> outShape(3, ShapedType::kDynamic);
1910
1911 const auto aDataShape = cast<ShapedType>(adaptor.getAData().getType());
1912 if (aDataShape.hasRank()) {
1913 outShape[0] = aDataShape.getDimSize(0);
1914 outShape[1] = aDataShape.getDimSize(1);
1915 }
1916
1917 const auto aScaleShape = cast<ShapedType>(adaptor.getAScale().getType());
1918 if (aScaleShape.hasRank()) {
1919 outShape[0] = ShapedType::isDynamic(outShape[0]) ? aScaleShape.getDimSize(0)
1920 : outShape[0];
1921 outShape[1] = ShapedType::isDynamic(outShape[1]) ? aScaleShape.getDimSize(1)
1922 : outShape[1];
1923 }
1924
1925 // If B batch size is 1, it is broadcast across A's batch size
1926 const auto bDataShape = cast<ShapedType>(adaptor.getBData().getType());
1927 if (bDataShape.hasRank()) {
1928 const int64_t bDataBatchSize = bDataShape.getDimSize(0);
1929 if (bDataBatchSize != 1)
1930 outShape[0] =
1931 ShapedType::isDynamic(outShape[0]) ? bDataBatchSize : outShape[0];
1932 outShape[2] = bDataShape.getDimSize(1);
1933 }
1934
1935 const auto bScaleShape = cast<ShapedType>(adaptor.getBScale().getType());
1936 if (bScaleShape.hasRank()) {
1937 const int64_t bScaleBatchSize = bScaleShape.getDimSize(0);
1938 if (bScaleBatchSize != 1)
1939 outShape[0] =
1940 ShapedType::isDynamic(outShape[0]) ? bScaleBatchSize : outShape[0];
1941 outShape[2] = ShapedType::isDynamic(outShape[2]) ? bScaleShape.getDimSize(1)
1942 : outShape[2];
1943 }
1944
1945 inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1946 return success();
1947}
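// Shape example (a sketch, assuming block_size = 32): A_data [2, 16, 64] and
// B_data [1, 32, 64] infer an output of [2, 16, 32]; B's batch of 1 is
// broadcast across A's batch, and per the verifier below the scale operands
// would carry C / block_size = 64 / 32 = 2 in their last dimension.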
1948
1949LogicalResult MatmulTBlockScaledOp::verify() {
1950 // Verify same input data types
1951 const Type aDataType = getAData().getType();
1952 const Type bDataType = getBData().getType();
1953 if (failed(verifySameElementTypes(*this, aDataType, bDataType, "A_data",
1954 "B_data")))
1955 return failure();
1956
1957 auto tryUpdateDimOrFailure = [&](int64_t &currDim, const int64_t newDim,
1958 const StringRef operandName,
1959 const StringRef dimName) -> LogicalResult {
1960 if (ShapedType::isDynamic(currDim)) {
1961 currDim = newDim;
1962 return success();
1963 } else if (ShapedType::isStatic(newDim) && currDim != newDim) {
1964 return emitOpError("expected ")
1965 << dimName << " of " << operandName << " to match size " << currDim
1966 << ", got " << newDim;
1967 }
1968 return success();
1969 };
1970
1971 // Verify input shape compatibility
1972 int64_t N = ShapedType::kDynamic;
1973 int64_t D = ShapedType::kDynamic;
1974 int64_t H = ShapedType::kDynamic;
1975 int64_t W = ShapedType::kDynamic;
1976 int64_t C = ShapedType::kDynamic;
1977 int64_t multiplesOfC = ShapedType::kDynamic;
1978
1979 const ShapeAdaptor aDataShape = ShapeAdaptor(aDataType);
1980 if (aDataShape.hasRank()) {
1981 N = aDataShape.getDimSize(0);
1982 H = aDataShape.getDimSize(1);
1983 C = aDataShape.getDimSize(2);
1984 }
1985
1986 const ShapeAdaptor aScaleShape = ShapeAdaptor(getAScale().getType());
1987 if (aScaleShape.hasRank()) {
1988 if (failed(tryUpdateDimOrFailure(N, aScaleShape.getDimSize(0), "a_scale",
1989 "batch")) ||
1990 failed(tryUpdateDimOrFailure(H, aScaleShape.getDimSize(1), "a_scale",
1991 "height")))
1992 return failure();
1993 multiplesOfC = aScaleShape.getDimSize(2);
1994 }
1995
1996 const ShapeAdaptor bDataShape = ShapeAdaptor(bDataType);
1997 if (bDataShape.hasRank()) {
1998 if (failed(tryUpdateDimOrFailure(D, bDataShape.getDimSize(0), "b_data",
1999 "batch")) ||
2000 failed(tryUpdateDimOrFailure(C, bDataShape.getDimSize(2), "b_data",
2001 "channels")))
2002 return failure();
2003 W = bDataShape.getDimSize(1);
2004 }
2005
2006 const ShapeAdaptor bScaleShape = ShapeAdaptor(getBScale().getType());
2007 if (bScaleShape.hasRank()) {
2008 if (failed(tryUpdateDimOrFailure(D, bScaleShape.getDimSize(0), "b_scale",
2009 "batch")) ||
2010 failed(tryUpdateDimOrFailure(W, bScaleShape.getDimSize(1), "b_scale",
2011 "width")) ||
2012 failed(tryUpdateDimOrFailure(multiplesOfC, bScaleShape.getDimSize(2),
2013 "b_scale", "C/block_size")))
2014 return failure();
2015 }
2016
2017 // Verify batch size is broadcast compatible
2018 if (ShapedType::isStatic(N) && ShapedType::isStatic(D) && N != D && D != 1)
2019 return emitOpError("expect B matrix batch size to be broadcast compatible "
2020 "with A, got D=")
2021 << D << " vs N=" << N;
2022
2023 // Verify C is a multiple of block size
2024 const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
2025 if (ShapedType::isStatic(C) && C % blockSize != 0)
2026 return emitOpError("expect C to be a multiple of block size, got C=")
2027 << C << ", block_size=" << blockSize;
2028
2029 // Verify multiplesOfC is C / block size
2030 if (ShapedType::isStatic(C) && ShapedType::isStatic(multiplesOfC) &&
2031 multiplesOfC != C / blockSize)
2032 return emitOpError(
2033 "expect scale operands dimension 2 to equal C/block_size (")
2034 << C << "/" << blockSize << ")"
2035 << ", got " << multiplesOfC;
2036
2037 // Verify output shape
2038 N = ShapedType::isDynamic(N) ? D : N;
2039 const SmallVector<int64_t, 3> expectedOutputShape = {N, H, W};
2040 const auto outputType = cast<ShapedType>(getResult().getType());
2041 if (outputType.hasRank() &&
2042 failed(
2043 verifyCompatibleShape(outputType.getShape(), expectedOutputShape))) {
2044 InFlightDiagnostic opError = emitOpError("expected output shape ");
2045 auto stringifyDim = [&](int64_t d) {
2046 if (ShapedType::isDynamic(d))
2047 opError << "?";
2048 else
2049 opError << d;
2050 };
2051 llvm::interleaveComma(outputType.getShape(), opError, stringifyDim);
2052 opError << " to be compatible with expected output shape ";
2053 llvm::interleaveComma(expectedOutputShape, opError, stringifyDim);
2054 return opError;
2055 }
2056
2057 return success();
2058}
2059
2060LogicalResult tosa::PadOp::inferReturnTypeComponents(
2061 MLIRContext *context, ::std::optional<Location> location,
2062 PadOp::Adaptor adaptor,
2063 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2064 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2065 auto paddingRank =
2066 cast<tosa::shapeType>(adaptor.getPadding().getType()).getRank();
2067 SmallVector<int64_t> outputShape;
2068
2069 // If the input rank is unknown, we can infer the output rank using the
2070 // padding shape's rank divided by 2.
2071 if (!inputShape.hasRank()) {
2072 outputShape.resize(paddingRank / 2, ShapedType::kDynamic);
2073 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2074 return success();
2075 }
2076
2077 SmallVector<int64_t> paddingValues;
2078 // If the paddings value is not a constant, all dimensions must be dynamic.
2079 if (!tosa::getConstShapeValues(adaptor.getPadding().getDefiningOp(),
2080 paddingValues)) {
2081 outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
2082 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2083 return success();
2084 }
2085
2086 outputShape.reserve(inputShape.getRank());
2087 for (int i = 0, s = inputShape.getRank(); i < s; i++) {
2088 if (inputShape.isDynamicDim(i)) {
2089 outputShape.push_back(ShapedType::kDynamic);
2090 continue;
2091 }
2092 auto padFront = paddingValues[i * 2];
2093 auto padBack = paddingValues[i * 2 + 1];
2094 if (padFront < 0 || padBack < 0) {
2095 // if either padding for dim i is -1, output dim is unknown
2096 outputShape.push_back(ShapedType::kDynamic);
2097 continue;
2098 }
2099
2100 outputShape.push_back(inputShape.getDimSize(i) + padFront + padBack);
2101 }
2102
2103 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2104 return success();
2105}
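// Shape example (a sketch): input [2, 3] with padding values [1, 1, 2, 2]
// (front/back pairs per dimension) infers [2 + 1 + 1, 3 + 2 + 2] = [4, 7];
// a negative (unknown) padding entry or a dynamic input dim leaves that
// output dimension dynamic.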
2106
2107LogicalResult tosa::PadOp::verify() {
2108 if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
2109 /* outType = */ getOutput().getType())
2110 .failed()) {
2111 return failure();
2112 }
2113
2114 if (auto padConst = getPadConst()) {
2115 if (verifySameElementTypes(*this, /* inType = */ padConst.getType(),
2116 /* outType = */ getOutput().getType())
2117 .failed()) {
2118 return failure();
2119 }
2120 }
2121
2122 RankedTensorType inputType =
2123 llvm::dyn_cast<RankedTensorType>(getInput1().getType());
2124 RankedTensorType outputType =
2125 llvm::dyn_cast<RankedTensorType>(getOutput().getType());
2126 if (!inputType || !outputType)
2127 return success();
2128
2129 auto inputRank = inputType.getRank();
2130 auto outputRank = outputType.getRank();
2131 if (inputRank != outputRank)
2132 return emitOpError() << "expect same input and output tensor rank, but got "
2133 << "inputRank: " << inputRank
2134 << ", outputRank: " << outputRank;
2135
2136 DenseIntElementsAttr paddingAttr;
2137 if (!matchPattern(getPadding(), m_Constant(&paddingAttr))) {
2138 return failure();
2139 }
2140
2141 auto paddingValues = paddingAttr.getValues<APInt>();
2142 if (paddingValues.size() != static_cast<size_t>(inputRank * 2))
2143 return emitOpError() << "padding tensor must have " << inputRank
2144 << " * 2 = " << inputRank * 2 << " elements, but got "
2145 << paddingValues.size();
2146
2147 auto inputShape = inputType.getShape();
2148 auto outputShape = outputType.getShape();
2149
2150 for (int64_t i = 0; i < inputRank; ++i) {
2151 int64_t padStart = paddingValues[i * 2].getSExtValue();
2152 int64_t padEnd = paddingValues[i * 2 + 1].getSExtValue();
2153
2154 if ((padStart < 0 && padStart != -1) || (padEnd < 0 && padEnd != -1)) {
2155 return emitOpError()
2156 << "invalid padding values at dimension " << i
2157 << ": values must be non-negative or -1 for dynamic padding, got ["
2158 << padStart << ", " << padEnd << "]";
2159 }
2160
2161 // Skip shape verification for dynamic input/output
2162 if (inputShape[i] == ShapedType::kDynamic ||
2163 outputShape[i] == ShapedType::kDynamic)
2164 continue;
2165
2166 if (outputShape[i] != inputShape[i] + padStart + padEnd) {
2167 return emitOpError() << "mismatch in output shape at dimension " << i
2168 << ": expected " << inputShape[i] << " + "
2169 << padStart << " + " << padEnd << " = "
2170 << (inputShape[i] + padStart + padEnd)
2171 << ", but got " << outputShape[i];
2172 }
2173 }
2174
2175 return success();
2176}
2177
2178LogicalResult tosa::SliceOp::inferReturnTypeComponents(
2179 MLIRContext *context, ::std::optional<Location> location,
2180 SliceOp::Adaptor adaptor,
2181 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2182
2183 Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
2184 SmallVector<int64_t> start;
2185 SmallVector<int64_t> size;
2186
2187 if (!tosa::getConstShapeValues(adaptor.getStart().getDefiningOp(), start) ||
2188 !tosa::getConstShapeValues(adaptor.getSize().getDefiningOp(), size)) {
2189 auto rank = cast<tosa::shapeType>(adaptor.getSize().getType()).getRank();
2190 SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
2191 inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
2192 return success();
2193 }
2194
2195 // if size[i] is -1, all remaining elements in dimension i are included
2196 // in the slice, similar to TF.
2197 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2198 // initialize outputShape to all unknown
2199 SmallVector<int64_t> outputShape(size.size(), ShapedType::kDynamic);
2200 if (inputShape.hasRank()) {
2201 for (size_t i = 0; i < size.size(); i++) {
2202 if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
2203 (ShapedType::isDynamic(inputShape.getDimSize(i)) ||
2204 start[i] < inputShape.getDimSize(i))) {
2205 // size[i] is not 0 and not < -1, and start[i] is in valid range
2206 if (ShapedType::isDynamic(inputShape.getDimSize(i))) {
2207 // input shape has unknown dim[i] - only valid if size[i] > 0
2208 if (size[i] > 0) {
2209 outputShape[i] = size[i];
2210 }
2211 } else {
2212 // input shape has known dim[i]
2213 if (size[i] == -1) {
2214 outputShape[i] = inputShape.getDimSize(i) - start[i];
2215 } else if (start[i] + size[i] <= inputShape.getDimSize(i)) {
2216 // start[i] + size[i] is within bound of input shape's dim[i]
2217 outputShape[i] = size[i];
2218 }
2219 }
2220 }
2221 }
2222 } else {
2223 outputShape = convertToMlirShape(size);
2224 }
2225 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2226 return success();
2227}
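// Shape example (a sketch): input [10, 8] with start = [2, 0] and
// size = [-1, 4] infers [8, 4], since size -1 takes the remaining
// 10 - 2 = 8 elements of that dimension; out-of-range start/size
// combinations leave the dimension dynamic.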
2228
2229LogicalResult tosa::SliceOp::verify() {
2230 if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
2231 /* outType = */ getOutput().getType())
2232 .failed())
2233 return failure();
2234
2235 const ShapeAdaptor inputShape(getInput1().getType());
2236 if (inputShape.hasRank()) {
2237 const auto inputRank = inputShape.getRank();
2238 const ShapeAdaptor outputShape(getOutput().getType());
2239 if (outputShape.hasRank() && inputRank != outputShape.getRank())
2240 return emitOpError(
2241 "expect input1 and output to have the same ranks, got ")
2242 << inputRank << " and " << outputShape.getRank();
2243
2244 const auto startShapeRank =
2245 llvm::cast<tosa::shapeType>(getStart().getType()).getRank();
2246 if (inputRank != startShapeRank)
2247 return emitOpError("length of start is not equal to rank of input shape");
2248
2249 const auto sizeShapeRank =
2250 llvm::cast<tosa::shapeType>(getSize().getType()).getRank();
2251 if (inputRank != sizeShapeRank)
2252 return emitOpError("length of size is not equal to rank of input shape");
2253 }
2254
2255 return success();
2256}
2257
2258LogicalResult tosa::MulOp::inferReturnTypeComponents(
2259 MLIRContext *context, ::std::optional<Location> location,
2260 ValueShapeRange operands, DictionaryAttr attributes,
2261 OpaqueProperties properties, RegionRange regions,
2262 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2263 // mul op's output shape depends only on input1 and input2, not on shift
2264 ValueShapeRange twoInputs = operands.drop_back();
2265 llvm::SmallVector<int64_t> outShape;
2266 if (resolveBroadcastShape(twoInputs, outShape).failed()) {
2267 inferredReturnShapes.push_back(ShapedTypeComponents());
2268 } else {
2269 inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
2270 }
2271 return success();
2272}
2273
2274LogicalResult tosa::MulOp::verify() {
2275 const Value output = getOutput();
2276 auto resElemType = getElementTypeOrSelf(output);
2277
2278 // Verify that the element types of the operands and the result match the
2279 // TOSA specification.
2280 if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
2281 IntegerType lhsIntType =
2282 dyn_cast<IntegerType>(getElementTypeOrSelf(getInput1()));
2283 IntegerType rhsIntType =
2284 dyn_cast<IntegerType>(getElementTypeOrSelf(getInput2()));
2285 if (!lhsIntType || !rhsIntType || lhsIntType != rhsIntType)
2286 return emitOpError("requires the same element type for all operands");
2287
2288 // Though the spec requires the element type of result to be i32, a more
2289 // relaxed way is provided at dialect level for easier cooperating with
2290 // other dialects.
2291 if (lhsIntType.getWidth() > resIntType.getWidth())
2292 return emitOpError("invalid data type size for operands or result");
2293
2294 } else {
2295 // For other supported types, the spec requires the same element
2296 // type for all operands (excluding the `shift` operand) and results.
2297 for (int i = 0; i < 2; ++i) {
2298 if (getElementTypeOrSelf(getOperand(i)) != resElemType)
2299 return emitOpError(
2300 "requires the same element type for all operands and results");
2301 }
2302
2303 // verify shift has value 0 for non-integer types
2304 ElementsAttr shift_elem;
2305 if (matchPattern(getShift(), m_Constant(&shift_elem))) {
2306 int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
2307 if (shift != 0) {
2308 return emitOpError() << "require shift to be 0 for float type";
2309 }
2310 }
2311 }
2312
2313 // Verify the op has the same rank for all main operands (excluding extra
2314 // operands such as the shift of mul op; this is the only difference from the
2315 // built-in `SameOperandsAndResultRank` trait) and result types, if known.
2316 TypeRange operandTypes = getOperandTypes();
2317 ShapedType aType = cast<ShapedType>(operandTypes[0]);
2318 ShapedType bType = cast<ShapedType>(operandTypes[1]);
2319
2320 const bool aHasRank = aType.hasRank();
2321 const bool bHasRank = bType.hasRank();
2322 if (aHasRank && bHasRank) {
2323 const int64_t aRank = aType.getRank();
2324 const int64_t bRank = bType.getRank();
2325 if (aRank != bRank)
2326 return emitOpError("a and b operands don't have matching ranks, got ")
2327 << aRank << " and " << bRank;
2328
2329 // check for broadcast compatible shapes
2330 SmallVector<int64_t> resultShape;
2331 if (!mlir::OpTrait::util::getBroadcastedShape(
2332 aType.getShape(), bType.getShape(), resultShape))
2333 return emitOpError("a and b operands don't have broadcast-compatible "
2334 "shapes, got ")
2335 << aType << " and " << bType;
2336 }
2337
2338 ShapedType resultType = cast<ShapedType>(output.getType());
2339 if (!resultType.hasRank())
2340 return success();
2341
2342 const int64_t resultRank = resultType.getRank();
2343 if (aHasRank && resultRank != aType.getRank())
2344 return emitOpError("result type has different rank than a, got ")
2345 << resultRank << " vs " << aType.getRank();
2346 if (bHasRank && resultRank != bType.getRank())
2347 return emitOpError("result type has different rank than b, got ")
2348 << resultRank << " vs " << bType.getRank();
2349
2350 return success();
2351}
2352
2353LogicalResult tosa::TableOp::inferReturnTypeComponents(
2354 MLIRContext *context, ::std::optional<Location> location,
2355 TableOp::Adaptor adaptor,
2356 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2357 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2358
2359 if (!inputShape.hasRank()) {
2360 inferredReturnShapes.push_back(ShapedTypeComponents());
2361 return success();
2362 }
2363
2364 inferredReturnShapes.resize(1);
2365 inputShape.getDims(inferredReturnShapes[0]);
2366 return success();
2367}
2368
2369LogicalResult tosa::TableOp::verify() {
2370 const TensorType inputType = getInput1().getType();
2371 const TensorType outputType = getOutput().getType();
2372
2373 if (!inputType.hasRank() || !outputType.hasRank())
2374 return success();
2375
2376 if (inputType.getRank() != outputType.getRank())
2377 return emitOpError()
2378 << "expected input tensor rank to equal result tensor rank";
2379
2380 auto inputDims = inputType.getShape();
2381 auto outputDims = outputType.getShape();
2382 for (auto it : llvm::enumerate(llvm::zip(inputDims, outputDims))) {
2383 int64_t dim = it.index();
2384 auto [inputDim, outputDim] = it.value();
2385 if (ShapedType::isStatic(outputDim) && outputDim != inputDim) {
2386 return emitOpError() << "dim(result, " << dim << ") = " << outputDim
2387 << " doesn't match dim(input, " << dim
2388 << ") = " << inputDim;
2389 }
2390 }
2391 return success();
2392}
2393
2394LogicalResult
2395tosa::TileOp::getConstantMultiples(SmallVector<int64_t> &multiples) {
2396 // Multiples must be constants.
2397 DenseIntElementsAttr multiplesAttr;
2398 if (!matchPattern(getMultiples(), m_Constant(&multiplesAttr)))
2399 return failure();
2400 multiples = llvm::to_vector(
2401 llvm::map_range(multiplesAttr.getValues<APInt>(),
2402 [](const APInt &val) { return val.getSExtValue(); }));
2403 return success();
2404}
2405
2406LogicalResult tosa::TileOp::inferReturnTypeComponents(
2407 MLIRContext *context, ::std::optional<Location> location,
2408 TileOp::Adaptor adaptor,
2409 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2410 Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
2411 SmallVector<int64_t> multiples;
2412 if (!tosa::getConstShapeValues(adaptor.getMultiples().getDefiningOp(),
2413 multiples)) {
2414 auto rank =
2415 cast<tosa::shapeType>(adaptor.getMultiples().getType()).getRank();
2416 SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
2417 inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
2418 return success();
2419 } else {
2420 multiples = convertToMlirShape(multiples);
2421 }
2422
2423 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2424 SmallVector<int64_t> outputShape;
2425 if (!inputShape.hasRank()) {
2426 outputShape.resize(multiples.size(), ShapedType::kDynamic);
2427 inferredReturnShapes.push_back(
2428 ShapedTypeComponents(outputShape, inputType));
2429 return success();
2430 } else if (static_cast<size_t>(inputShape.getRank()) != multiples.size())
2431 return failure();
2432
2433 // Any non-dynamic input dimension times its multiple gives a known output size.
2434 outputShape.reserve(multiples.size());
2435 for (int i = 0, s = inputShape.getRank(); i < s; i++) {
2436 if (multiples[i] == ShapedType::kDynamic) {
2437 outputShape.push_back(ShapedType::kDynamic);
2438 } else {
2439 int64_t dim = inputShape.getDimSize(i);
2440 if (dim != ShapedType::kDynamic)
2441 dim *= multiples[i];
2442 outputShape.push_back(dim);
2443 }
2444 }
2445
2446 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
2447 return success();
2448}
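// Shape example (a sketch): input [2, 3] with multiples [3, 1] infers [6, 3];
// a dynamic input dim or a dynamic multiple leaves the corresponding output
// dimension dynamic.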
2449
2450LogicalResult tosa::TileOp::verify() {
2451 if (verifySameElementTypes(*this, /* intype = */ getInput1().getType(),
2452 /* outType = */ getOutput().getType())
2453 .failed()) {
2454 return failure();
2455 }
2456 ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
2457 ShapedType outputType = llvm::cast<ShapedType>(getType());
2458
2459 shapeType multiplesType =
2460 llvm::cast<tosa::shapeType>(getMultiples().getType());
2461
2462 auto multiplesRank = multiplesType.getRank();
2463
2464 if (inputType.hasRank()) {
2465 if (inputType.getRank() != multiplesRank)
2466 return emitOpError("expect 'multiples' to have rank ")
2467 << inputType.getRank() << " but got " << multiplesRank << ".";
2468 if (outputType.hasRank() && inputType.getRank() != outputType.getRank())
2469 return emitOpError("expect same input and output tensor rank.");
2470 } else if (outputType.hasRank() && outputType.getRank() != multiplesRank)
2471 return emitOpError("expect 'multiples' array to have length ")
2472 << outputType.getRank() << " but got " << multiplesRank << ".";
2473
2474 SmallVector<int64_t> multiples;
2475 if (getConstantMultiples(multiples).succeeded() &&
2476 llvm::any_of(multiples, [](int64_t v) { return v <= 0 && v != -1; }))
2477 return emitOpError(
2478 "expect element of 'multiples' to be positive integer or -1.");
2479
2480 return success();
2481}
2482
2483bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
2484 if (l.size() != r.size() || l.size() != 1)
2485 return false;
2486 return getElementTypeOrSelf(l[0]) == getElementTypeOrSelf(r[0]);
2487}
2488
2489LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
2490 MLIRContext *context, ::std::optional<Location> location,
2491 ReshapeOp::Adaptor adaptor,
2492 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2493 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2494 Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
2495 llvm::SmallVector<int64_t> newShapeValue;
2496 if (!tosa::getConstShapeValues(adaptor.getShape().getDefiningOp(),
2497 newShapeValue)) {
2498 auto rank = cast<tosa::shapeType>(adaptor.getShape().getType()).getRank();
2499 SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
2500 inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
2501 return success();
2502 } else {
2503 newShapeValue = convertToMlirShape(newShapeValue);
2504 }
2505
2506 // We cannot infer from the total number of elements so we must take the
2507 // shape attribute as exact.
2508 if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
2509 inferredReturnShapes.push_back(
2510 ShapedTypeComponents(newShapeValue, inputType));
2511 return success();
2512 }
2513
2514 // Determine the number of elements covered by the slice of all static
2515 // dimensions. This allows us to infer the length of the remaining dynamic
2516 // dimension.
2517 int64_t numElements = inputShape.getNumElements();
2518 int64_t staticMul = 1;
2519 for (auto val : newShapeValue) {
2520 if (ShapedType::isStatic(val)) {
2521 staticMul *= val;
2522 }
2523 }
2524
2525 // Determine the length of the dynamic dimension.
2526 for (auto &val : newShapeValue) {
2527 if (ShapedType::isDynamic(val))
2528 val = numElements / staticMul;
2529 }
2530
2531 inferredReturnShapes.push_back(
2532 ShapedTypeComponents(newShapeValue, inputType));
2533 return success();
2534}
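// Shape example (a sketch): a static input holding 24 elements reshaped with
// target shape [2, -1, 4] infers the placeholder as 24 / (2 * 4) = 3, giving
// [2, 3, 4]; with a non-static input the target shape is taken as-is.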
2535
2536llvm::LogicalResult tosa::ReshapeOp::verify() {
2537 if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
2538 /* outType = */ getOutput().getType())
2539 .failed()) {
2540 return failure();
2541 }
2542 TensorType inputType = getInput1().getType();
2543
2544 SmallVector<int64_t> shapeValues;
2545 if (!tosa::getConstShapeValues(getShape().getDefiningOp(), shapeValues)) {
2546 // skip following checks if shape is not constant
2547 return mlir::success();
2548 }
2549
2550 int missingDims = llvm::count(shapeValues, -1);
2551 if (missingDims > 1)
2552 return emitOpError() << "expected at most one target dimension to be -1";
2553
2554 const auto outputType = dyn_cast<RankedTensorType>(getType());
2555 if (!outputType)
2556 return success();
2557
2558 if ((int64_t)shapeValues.size() != outputType.getRank())
2559 return emitOpError() << "new shape does not match result rank";
2560
2561 for (auto [newShapeDim, outputShapeDim] :
2562 zip(shapeValues, outputType.getShape())) {
2563 if (newShapeDim != -1 && newShapeDim != ShapedType::kDynamic &&
2564 outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim)
2565 return emitOpError() << "new shape is inconsistent with result shape";
2566
2567 if (newShapeDim != ShapedType::kDynamic && newShapeDim < -1)
2568 return emitOpError() << "new shape has invalid tensor dimension size "
2569 << newShapeDim;
2570 }
2571
2572 if (inputType.hasStaticShape()) {
2573 int64_t inputElementsNum = inputType.getNumElements();
2574 if (outputType.hasStaticShape()) {
2575 int64_t outputElementsNum = outputType.getNumElements();
2576 if (inputElementsNum != outputElementsNum) {
2577 return emitOpError() << "cannot reshape " << inputElementsNum
2578 << " elements into " << outputElementsNum;
2579 }
2580 }
2581
2582 int64_t newShapeElementsNum =
2583 llvm::accumulate(shapeValues, int64_t(1), [](int64_t acc, int64_t dim) {
2584 return (dim > 0) ? acc * dim : acc;
2585 });
2586 bool isStaticNewShape =
2587 llvm::all_of(shapeValues, [](int64_t s) { return s > 0; });
2588 if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
2589 (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
2590 return emitOpError() << "cannot reshape " << inputElementsNum
2591 << " elements into " << newShapeElementsNum;
2592 }
2593 }
2594
2595 return mlir::success();
2596}
2597
2598 // return failure if val is not a constant
2599 // set zp to -1 if val is a non-zero float or is neither integer nor float
2600 // otherwise set zp to val's constant value
2601static FailureOr<int64_t> getZeroPoint(Value val, bool signExtend) {
2602 ElementsAttr zpAttr;
2603 if (!matchPattern(val, m_Constant(&zpAttr))) {
2604 return failure();
2605 }
2606
2607 Type zpElemType = zpAttr.getElementType();
2608
2609 if (llvm::isa<FloatType>(zpElemType)) {
2610 if (zpAttr.getValues<APFloat>()[0].isZero()) {
2611 return 0;
2612 }
2613 // return non-zero value to trigger error check
2614 return -1;
2615 }
2616
2617 if (llvm::isa<IntegerType>(zpElemType)) {
2618 if (signExtend)
2619 return zpAttr.getValues<APInt>()[0].getSExtValue();
2620 else
2621 return zpAttr.getValues<APInt>()[0].getZExtValue();
2622 }
2623
2624 // return non-zero value to trigger error check
2625 return -1;
2626}
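// Behaviour sketch: a constant f32 0.0 zero point yields 0, any non-zero float
// (or a non-int/non-float element type) yields the -1 sentinel that the
// verifier rejects, and an i8 constant -128 yields -128 when sign-extended but
// 128 when zero-extended (signExtend == false).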
2627
2628template <typename T>
2629static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp,
2630 const std::string &operand) {
2631 Type zpElemType = getElementTypeOrSelf(val);
2632
2633 if (!zpElemType.isInteger(8) && zp != 0) {
2634 // convert operand to lower case for error message
2635 std::string lower = operand;
2636 llvm::transform(lower, lower.begin(), ::tolower);
2637 return op.emitOpError()
2638 << lower << " zero point must be zero for non-int8 integer types";
2639 }
2640
2641 return success();
2642}
2643
2644static LogicalResult verifyZeroPoint(tosa::RescaleOp op, Value zpVal,
2645 const int64_t &zp,
2646 const std::string &operand) {
2647 bool isInputZp = (operand == "Input");
2648
2649 bool tensorUnsigned =
2650 isInputZp ? op.getInputUnsigned() : op.getOutputUnsigned();
2651 StringRef tensorName = isInputZp ? "input" : "output";
2652
2653 Type zpElemType = getElementTypeOrSelf(zpVal);
2654
2655 if (zp != 0) {
2656 if (!zpElemType.isInteger(8) &&
2657 !(zpElemType.isInteger(16) && tensorUnsigned)) {
2658 return op.emitOpError()
2659 << "expect " << tensorName << "_zp of 0, got " << zp;
2660 }
2661 if (zpElemType.isInteger(16) && tensorUnsigned && zp != 32768) {
2662 return op.emitOpError() << "expect " << tensorName
2663 << "_zp of 0 or 32768 for unsigned int16 "
2664 << tensorName << ", got " << zp;
2665 }
2666 }
2667
2668 return success();
2669}
2670
2671#define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND) \
2672 FailureOr<int64_t> tosa::OP::get##OPERAND_NAME##ZeroPoint() { \
2673 return getZeroPoint(get##OPERAND_NAME##Zp(), SIGN_EXTEND); \
2674 } \
2675 LogicalResult tosa::OP::verify##OPERAND_NAME##ZeroPoint(int64_t zp) { \
2676 return verifyZeroPoint(*this, get##OPERAND_NAME##Zp(), zp, #OPERAND_NAME); \
2677 }
2678
2679ZERO_POINT_HELPER(Conv2DOp, Input, true)
2680ZERO_POINT_HELPER(Conv2DOp, Weight, true)
2681ZERO_POINT_HELPER(Conv3DOp, Input, true)
2682ZERO_POINT_HELPER(Conv3DOp, Weight, true)
2683ZERO_POINT_HELPER(DepthwiseConv2DOp, Input, true)
2684ZERO_POINT_HELPER(DepthwiseConv2DOp, Weight, true)
2685ZERO_POINT_HELPER(TransposeConv2DOp, Input, true)
2686ZERO_POINT_HELPER(TransposeConv2DOp, Weight, true)
2687ZERO_POINT_HELPER(AvgPool2dOp, Input, true)
2688ZERO_POINT_HELPER(AvgPool2dOp, Output, true)
2689ZERO_POINT_HELPER(MatMulOp, A, true)
2690ZERO_POINT_HELPER(MatMulOp, B, true)
2691ZERO_POINT_HELPER(NegateOp, Input1, true)
2692ZERO_POINT_HELPER(NegateOp, Output, true)
2693ZERO_POINT_HELPER(RescaleOp, Input, !getInputUnsigned())
2694ZERO_POINT_HELPER(RescaleOp, Output, !getOutputUnsigned())
2695#undef ZERO_POINT_HELPER
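// For reference, ZERO_POINT_HELPER(MatMulOp, A, true) above expands to roughly:
//   FailureOr<int64_t> tosa::MatMulOp::getAZeroPoint() {
//     return getZeroPoint(getAZp(), true);
//   }
//   LogicalResult tosa::MatMulOp::verifyAZeroPoint(int64_t zp) {
//     return verifyZeroPoint(*this, getAZp(), zp, "A");
//   }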
2696
2697LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
2698 MLIRContext *context, ::std::optional<Location> location,
2699 TransposeOp::Adaptor adaptor,
2700 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2701 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2702
2703 // If the input rank is unknown, the output rank is also unknown.
2705 if (!inputShape.hasRank()) {
2706 inferredReturnShapes.push_back(ShapedTypeComponents());
2707 return success();
2708 }
2709
2710 const auto inputRank = inputShape.getRank();
2711
2712 // This would imply the number of permutations does not match the rank of
2713 // the input which is illegal.
2714 if (adaptor.getPerms().size() != static_cast<size_t>(inputRank)) {
2715 return failure();
2716 }
2717
2718 SmallVector<int64_t> outputShape;
2719 // Rank-0 means no permutations matter.
2720 if (inputRank == 0) {
2721 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2722 return success();
2723 }
2724
2725 // Check whether the input dimensions are all the same.
2726 bool allTheSame = true;
2727 for (int i = 1, s = inputRank; i < s; i++) {
2728 if (inputShape.getDimSize(0) != inputShape.getDimSize(i)) {
2729 allTheSame = false;
2730 break;
2731 }
2732 }
2733
2734 // If all of the input dimensions are the same we don't care about the
2735 // permutation.
2736 if (allTheSame) {
2737 outputShape.resize(inputRank, inputShape.getDimSize(0));
2738 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2739 return success();
2740 }
2741
2742 outputShape.resize(inputRank, ShapedType::kDynamic);
2743
2744 // Constant permutation values must be within the input rank.
2745 if (llvm::any_of(adaptor.getPerms(),
2746 [inputRank](const auto i) { return i >= inputRank; }))
2747 return failure();
2748
2749 outputShape.reserve(inputRank);
2750 for (int i = 0, s = inputRank; i < s; i++) {
2751 outputShape[i] = inputShape.getDimSize(adaptor.getPerms()[i]);
2752 }
2753
2754 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2755 return success();
2756}
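// Shape example (a sketch): input [1, 2, 3] with perms = [2, 0, 1] infers
// [3, 1, 2], since output dim i takes the size of input dim perms[i]; if every
// input dim has the same size the permutation is irrelevant and the input
// shape is reused directly.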
2757
2758LogicalResult tosa::TransposeOp::verify() {
2759 if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
2760 /* outType = */ getOutput().getType())
2761 .failed()) {
2762 return failure();
2763 }
2764
2765 const ShapeAdaptor inputShape(getInput1().getType());
2766 const ShapeAdaptor outputShape(getOutput().getType());
2767
2768 const llvm::ArrayRef<int32_t> constantPerms = getPerms();
2769
2770 if (inputShape.hasRank() &&
2771 constantPerms.size() != static_cast<size_t>(inputShape.getRank()))
2772 return emitOpError() << "expected perms attribute to have size "
2773 << inputShape.getRank()
2774 << " (input rank) but got size "
2775 << constantPerms.size();
2776
2777 if (inputShape.hasRank() && outputShape.hasRank() &&
2778 inputShape.getRank() != outputShape.getRank())
2779 return emitOpError()
2780 << "expected input tensor rank to equal result tensor rank";
2781
2782 if (outputShape.hasRank() &&
2783 constantPerms.size() != static_cast<size_t>(outputShape.getRank()))
2784 return emitOpError() << "expected perms attribute to have size "
2785 << outputShape.getRank()
2786 << " (output rank) but got size "
2787 << constantPerms.size();
2788
2789 if (!llvm::all_of(constantPerms,
2790 [&constantPerms](int32_t s) {
2791 return s >= 0 &&
2792 static_cast<size_t>(s) < constantPerms.size();
2793 }) ||
2794 !isPermutationVector(llvm::to_vector(llvm::map_range(
2795 constantPerms, [](int32_t v) -> int64_t { return v; }))))
2796 return emitOpError() << "expected valid permutation indices";
2797
2798 // ERROR_IF(tensor_size(shape1) != tensor_size(shape))
2799 if (inputShape.hasStaticShape() && outputShape.hasStaticShape() &&
2800 inputShape.getNumElements() != outputShape.getNumElements())
2801 return emitOpError() << "expected input1 and output to have same numbers "
2802 "of elements, got "
2803 << inputShape.getNumElements() << " and "
2804 << outputShape.getNumElements();
2805
2806 // Verify that the input and output shapes are consistent under the
2807 // permutation.
2808 if (inputShape.hasRank() && outputShape.hasRank()) {
2809 for (auto i = 0; i < outputShape.getRank(); i++) {
2810 if (inputShape.isDynamicDim(constantPerms[i]) ||
2811 outputShape.isDynamicDim(i))
2812 continue;
2813
2814 if (inputShape.getDimSize(constantPerms[i]) != outputShape.getDimSize(i))
2815 return emitOpError()
2816 << "expected output tensor dim " << i << " to match "
2817 << "input dim " << constantPerms[i] << " with value of "
2818 << inputShape.getDimSize(constantPerms[i]);
2819 }
2820 }
2821
2822 return success();
2823}
2824
2825LogicalResult TransposeOp::reifyResultShapes(
2826 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2827
2828 const llvm::ArrayRef<int32_t> transposePerms = getPerms();
2829
2830 Value input = getInput1();
2831 auto inputType = cast<TensorType>(input.getType());
2832
2833 SmallVector<OpFoldResult> returnedDims(inputType.getRank());
2834 for (auto dim : transposePerms) {
2835 int32_t dimInInput = transposePerms[dim];
2836 if (inputType.isDynamicDim(dimInInput))
2837 returnedDims[dim] =
2838 tensor::DimOp::create(builder, getLoc(), input, dimInInput)
2839 .getResult();
2840 else
2841 returnedDims[dim] =
2842 builder.getIndexAttr(inputType.getDimSize(dimInInput));
2843 }
2844
2845 reifiedReturnShapes.emplace_back(std::move(returnedDims));
2846 return success();
2847}
2848
2849LogicalResult tosa::GatherOp::inferReturnTypeComponents(
2850 MLIRContext *context, ::std::optional<Location> location,
2851 GatherOp::Adaptor adaptor,
2852 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2853 llvm::SmallVector<int64_t> outputShape;
2854 outputShape.resize(3, ShapedType::kDynamic);
2855
2856 ShapeAdaptor valuesShape(adaptor.getValues().getType());
2857 if (valuesShape.hasRank()) {
2858 outputShape[0] = valuesShape.getDimSize(0);
2859 outputShape[2] = valuesShape.getDimSize(2);
2860 }
2861
2862 ShapeAdaptor indicesShape(adaptor.getIndices().getType());
2863 if (indicesShape.hasRank()) {
2864 if (outputShape[0] == ShapedType::kDynamic)
2865 outputShape[0] = indicesShape.getDimSize(0);
2866 if (outputShape[1] == ShapedType::kDynamic)
2867 outputShape[1] = indicesShape.getDimSize(1);
2868 }
2869
2870 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2871 return success();
2872}
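// Shape example (a sketch): values [2, 10, 4] (N, K, C) and indices [2, 5]
// (N, W) infer an output of [2, 5, 4] (N, W, C); a dynamic N or W is filled in
// from whichever operand knows it.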
2873
2874LogicalResult tosa::GatherOp::verify() {
2875 if (verifySameElementTypes(*this, /* inType = */ getValues().getType(),
2876 /* outType = */ getOutput().getType())
2877 .failed()) {
2878 return failure();
2879 }
2880
2881 const ShapeAdaptor valuesShape(getValues().getType());
2882 const ShapeAdaptor indicesShape(getIndices().getType());
2883 const ShapeAdaptor outputShape(getOutput().getType());
2884
2885 int64_t N = ShapedType::kDynamic;
2886 int64_t W = ShapedType::kDynamic;
2887 int64_t C = ShapedType::kDynamic;
2888
2889 if (valuesShape.hasRank()) {
2890 N = valuesShape.getDimSize(0);
2891 C = valuesShape.getDimSize(2);
2892 }
2893 if (indicesShape.hasRank()) {
2894 const int64_t indicesN = indicesShape.getDimSize(0);
2895 W = indicesShape.getDimSize(1);
2896 if (N == ShapedType::kDynamic)
2897 N = indicesN;
2898 else if (indicesN != ShapedType::kDynamic && N != indicesN)
2899 return emitOpError() << "requires indices dimension 0 to have size " << N
2900 << ", got " << indicesN;
2901 }
2902 if (outputShape.hasRank()) {
2903 const int64_t outputN = outputShape.getDimSize(0);
2904 const int64_t outputW = outputShape.getDimSize(1);
2905 const int64_t outputC = outputShape.getDimSize(2);
2906 if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
2907 N != outputN)
2908 return emitOpError() << "requires output dimension 0 to have size " << N
2909 << ", got " << outputN;
2910
2911 if (W != ShapedType::kDynamic && outputW != ShapedType::kDynamic &&
2912 W != outputW)
2913 return emitOpError() << "requires output dimension 1 to have size " << W
2914 << ", got " << outputW;
2915 if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
2916 C != outputC)
2917 return emitOpError() << "requires output dimension 2 to have size " << C
2918 << ", got " << outputC;
2919 }
2920 return success();
2921}
2922
2923LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
2924 MLIRContext *context, ::std::optional<Location> location,
2925 ResizeOp::Adaptor adaptor,
2926 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2927 llvm::SmallVector<int64_t, 4> outputShape;
2928 outputShape.resize(4, ShapedType::kDynamic);
2929
2930 ShapeAdaptor inputShape(adaptor.getInput().getType());
2931 if (!inputShape.hasRank())
2932 return failure();
2933
2934 outputShape[0] = inputShape.getDimSize(0);
2935 outputShape[3] = inputShape.getDimSize(3);
2936 int64_t inputHeight = inputShape.getDimSize(1);
2937 int64_t inputWidth = inputShape.getDimSize(2);
2938
2939 if ((inputHeight == ShapedType::kDynamic) ||
2940 (inputWidth == ShapedType::kDynamic))
2941 return failure();
2942
2943 SmallVector<int64_t> scaleInt, offsetInt, borderInt;
2944 if (!tosa::getConstShapeValues(adaptor.getScale().getDefiningOp(),
2945 scaleInt) ||
2946 !tosa::getConstShapeValues(adaptor.getOffset().getDefiningOp(),
2947 offsetInt) ||
2948 !tosa::getConstShapeValues(adaptor.getBorder().getDefiningOp(),
2949 borderInt)) {
2950 return failure();
2951 }
2952
2953 // Compute the output shape based on attributes: scale, offset, and border.
2954 const int64_t outputHeight =
2955 (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) /
2956 scaleInt[1]) +
2957 1;
2958
2959 const int64_t outputWidth =
2960 (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) /
2961 scaleInt[3]) +
2962 1;
2963
2964 if (outputHeight < 0 || outputWidth < 0) {
2965 return emitOptionalError(
2966 location,
2967 "calculated output height and width must be non-negative, "
2968 "got height = ",
2969 outputHeight, ", width = ", outputWidth);
2970 }
2971
2972 outputShape[1] = outputHeight;
2973 outputShape[2] = outputWidth;
2974 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2975 return success();
2976}
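// Shape example (a sketch): for input height 8, scale = [2, 1, ...],
// offset = 0 and border = 0, the inferred output height is
// ((8 - 1) * 2 - 0 + 0) / 1 + 1 = 15; the width follows the same formula with
// scale[2]/scale[3], offset[1] and border[1].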
2977
2978LogicalResult tosa::ResizeOp::verify() {
2979 const Value input = getInput();
2980 const Value output = getOutput();
2981 const RankedTensorType inputType =
2982 llvm::dyn_cast<RankedTensorType>(input.getType());
2983 const RankedTensorType outputType =
2984 llvm::dyn_cast<RankedTensorType>(output.getType());
2985
2986 SmallVector<int64_t> scaleValues;
2987 SmallVector<int64_t> offsetValues;
2988 SmallVector<int64_t> borderValues;
2989 if (!tosa::getConstShapeValues(getScale().getDefiningOp(), scaleValues) ||
2990 !tosa::getConstShapeValues(getOffset().getDefiningOp(), offsetValues) ||
2991 !tosa::getConstShapeValues(getBorder().getDefiningOp(), borderValues)) {
2992 // Skip following checks if shape is not constant
2993 return success();
2994 }
2995
2996 if (llvm::any_of(scaleValues, [](int64_t s) { return s <= 0; }))
2997 return emitOpError("expect all scale values to be > 0, got ")
2998 << scaleValues;
2999
3000 const int64_t scaleYN = scaleValues[0];
3001 const int64_t scaleYD = scaleValues[1];
3002 const int64_t scaleXN = scaleValues[2];
3003 const int64_t scaleXD = scaleValues[3];
3004
3005 const int64_t offsetY = offsetValues[0];
3006 const int64_t offsetX = offsetValues[1];
3007
3008 const int64_t borderY = borderValues[0];
3009 const int64_t borderX = borderValues[1];
3010
3011 if (!inputType)
3012 return success();
3013 if (!outputType)
3014 return success();
3015
3016 const int64_t oh = outputType.getDimSize(1);
3017 const int64_t ow = outputType.getDimSize(2);
3018 const int64_t ih = inputType.getDimSize(1);
3019 const int64_t iw = inputType.getDimSize(2);
3020
3021 // Only check when the input height cannot be broadcast (ih != 1),
3022 // since Linalg, a consumer of TOSA, expects broadcasting support
3023 // in resize to be available. Taking the cautious approach for now,
3024 // we can consider removing support for broadcasting later.
3025 if (ih != ShapedType::kDynamic && ih != 1) {
3026 const std::optional<int64_t> calculatedOutHeightMinusOne =
3027 idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
3028 if (!calculatedOutHeightMinusOne.has_value())
3029 return emitOpError("expected (input_height - 1) * scale_y_n - offset_y + "
3030 "border_y ")
3031 << "to be wholly divisible by scale_y_d, got ((" << ih
3032 << " - 1) * " << scaleYN << " - " << offsetY << " + " << borderY
3033 << ") / " << scaleYD;
3034 const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
3035 if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
3036 return emitOpError("calculated output height did not match expected: ")
3037 << "calculated=" << calculatedOutHeight << ", expected=" << oh;
3038 }
3039
3040 // Only check when the input width cannot be broadcast (iw != 1),
3041 // since Linalg, a consumer of TOSA, expects broadcasting support
3042 // in resize to be available. Taking the cautious approach for now,
3043 // we can consider removing support for broadcasting later.
3044 if (iw != ShapedType::kDynamic && iw != 1) {
3045 const int64_t scaledInWidth = (iw - 1) * scaleXN - offsetX + borderX;
3046 const std::optional<int64_t> calculatedOutWidthMinusOne =
3047 idivCheck(scaledInWidth, scaleXD);
3048 if (!calculatedOutWidthMinusOne.has_value())
3049 return emitOpError("expected (input_width - 1) * scale_x_n - offset_x + "
3050 "border_x ")
3051 << "to be wholly divisible by scale_x_d, got ((" << iw
3052 << " - 1) * " << scaleXN << " - " << offsetX << " + " << borderX
3053 << ") / " << scaleXD;
3054 const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
3055 if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
3056 return emitOpError("calculated output width did not match expected: ")
3057 << "calculated=" << calculatedOutWidth << ", expected=" << ow;
3058 }
3059
3060 return success();
3061}
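// A worked example of the divisibility check above, using hypothetical
// shapes: for an NHWC input of shape [1, 8, 8, C] with scale = [4, 2, 4, 2],
// offset = [0, 0] and border = [0, 0], (IH - 1) * scale_y_n - offset_y +
// border_y = (8 - 1) * 4 = 28, which is divisible by scale_y_d = 2, so the
// expected output height is 28 / 2 + 1 = 15 (and likewise OW = 15). A
// non-divisible numerator or a mismatching static output dimension is
// rejected by the verifier.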
3062
3063LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
3064 MLIRContext *context, ::std::optional<Location> location,
3065 ScatterOp::Adaptor adaptor,
3066 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3067 llvm::SmallVector<int64_t> outputShape;
3068 outputShape.resize(3, ShapedType::kDynamic);
3069
3070 ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());
3071 if (valuesInShape.hasRank()) {
3072 outputShape[0] = valuesInShape.getDimSize(0);
3073 outputShape[1] = valuesInShape.getDimSize(1);
3074 outputShape[2] = valuesInShape.getDimSize(2);
3075 }
3076
3077 ShapeAdaptor indicesShape(adaptor.getIndices().getType());
3078 if (indicesShape.hasRank()) {
3079 if (outputShape[0] == ShapedType::kDynamic)
3080 outputShape[0] = indicesShape.getDimSize(0);
3081 }
3082
3083 ShapeAdaptor inputShape(adaptor.getInput().getType());
3084 if (inputShape.hasRank()) {
3085 if (outputShape[0] == ShapedType::kDynamic)
3086 outputShape[0] = inputShape.getDimSize(0);
3087 if (outputShape[2] == ShapedType::kDynamic)
3088 outputShape[2] = inputShape.getDimSize(2);
3089 }
3090
3091 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3092 return success();
3093}
3094
3095LogicalResult tosa::ScatterOp::verify() {
3096 if (verifySameElementTypes(*this, /* inType = */ getValuesIn().getType(),
3097 /* outType = */ getValuesOut().getType())
3098 .failed() ||
3099 verifySameElementTypes(*this, /* inType = */ getInput().getType(),
3100 /* outType = */ getValuesOut().getType())
3101 .failed()) {
3102 return failure();
3103 }
3104
3105 const ShapeAdaptor valuesInShape(getValuesIn().getType());
3106 const ShapeAdaptor indicesShape(getIndices().getType());
3107 const ShapeAdaptor inputShape(getInput().getType());
3108 const ShapeAdaptor outputShape(getValuesOut().getType());
3109
3110 int64_t N = ShapedType::kDynamic;
3111 int64_t K = ShapedType::kDynamic;
3112 int64_t W = ShapedType::kDynamic;
3113 int64_t C = ShapedType::kDynamic;
3114 if (valuesInShape.hasRank()) {
3115 N = valuesInShape.getDimSize(0);
3116 K = valuesInShape.getDimSize(1);
3117 C = valuesInShape.getDimSize(2);
3118 }
3119 if (indicesShape.hasRank()) {
3120 const int64_t indicesN = indicesShape.getDimSize(0);
3121 W = indicesShape.getDimSize(1);
3122 if (N == ShapedType::kDynamic)
3123 N = indicesN;
3124 else if (indicesN != ShapedType::kDynamic && N != indicesN)
3125 return emitOpError() << "requires indices dimension 0 to have size " << N
3126 << ", got " << indicesN;
3127 }
3128 if (inputShape.hasRank()) {
3129 const int64_t inputN = inputShape.getDimSize(0);
3130 const int64_t inputW = inputShape.getDimSize(1);
3131 const int64_t inputC = inputShape.getDimSize(2);
3132 if (N == ShapedType::kDynamic)
3133 N = inputN;
3134 else if (inputN != ShapedType::kDynamic && N != inputN)
3135 return emitOpError() << "requires input dimension 0 to have size " << N
3136 << ", got " << inputN;
3137 if (W == ShapedType::kDynamic)
3138 W = inputW;
3139 else if (inputW != ShapedType::kDynamic && W != inputW)
3140 return emitOpError() << "requires input dimension 1 to have size " << W
3141 << ", got " << inputW;
3142
3143 if (C == ShapedType::kDynamic)
3144 C = inputC;
3145 else if (inputC != ShapedType::kDynamic && C != inputC)
3146 return emitOpError() << "requires input dimension 2 to have size " << C
3147 << ", got " << inputC;
3148 }
3149 if (outputShape.hasRank()) {
3150 const int64_t outputN = outputShape.getDimSize(0);
3151 const int64_t outputK = outputShape.getDimSize(1);
3152 const int64_t outputC = outputShape.getDimSize(2);
3153 if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
3154 N != outputN)
3155 return emitOpError() << "requires values_out dimension 0 to have size "
3156 << N << ", got " << outputN;
3157 if (K == ShapedType::kDynamic)
3158 K = outputK;
3159 else if (outputK != ShapedType::kDynamic && K != outputK)
3160 return emitOpError() << "requires values_out dimension 1 to have size "
3161 << K << ", got " << outputK;
3162 if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
3163 C != outputC)
3164 return emitOpError() << "requires values_out dimension 2 to have size "
3165 << C << ", got " << outputC;
3166 }
3167 if (K != ShapedType::kDynamic && W != ShapedType::kDynamic && !(K >= W))
3168 return emitOpError() << "requires dimensions K >= W, got K=" << K
3169 << " and W=" << W;
3170
3171 return success();
3172}
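// Illustrative (hypothetical) shapes for the checks above: values_in
// [N, K, C] = [2, 8, 4], indices [N, W] = [2, 3], input [N, W, C] = [2, 3, 4]
// and values_out [N, K, C] = [2, 8, 4] verify successfully since K = 8 >=
// W = 3; an indices tensor of shape [2, 9] would instead trip the K >= W
// check.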
3173
3174static LogicalResult ReduceInferReturnTypes(
3175 ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
3176 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3177 int64_t axisVal = axis.getValue().getSExtValue();
3178 if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) {
3179 inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
3180 return success();
3181 }
3182
3183 SmallVector<int64_t> outputShape;
3184 operandShape.getDims(outputShape);
3185 outputShape[axisVal] = 1;
3186 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
3187 return success();
3188}
3189
3190#define COMPATIBLE_RETURN_TYPES(OP) \
3191 bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) { \
3192 if (l.size() != r.size() || l.size() != 1) \
3193 return false; \
3194 if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0])) \
3195 return false; \
3196 return succeeded(verifyCompatibleShape(l[0], r[0])); \
3197 }
3198
3199#define REDUCE_SHAPE_INFER(OP) \
3200 LogicalResult OP::inferReturnTypeComponents( \
3201 MLIRContext *context, ::std::optional<Location> location, \
3202 OP::Adaptor adaptor, \
3203 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
3204 Type inputType = \
3205 llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \
3206 ShapeAdaptor inputShape(adaptor.getInput().getType()); \
3207 const Properties &prop = adaptor.getProperties(); \
3208 return ReduceInferReturnTypes(inputShape, inputType, prop.axis, \
3209 inferredReturnShapes); \
3210 } \
3211 COMPATIBLE_RETURN_TYPES(OP)
3212
3213REDUCE_SHAPE_INFER(tosa::ReduceAllOp)
3214REDUCE_SHAPE_INFER(tosa::ReduceAnyOp)
3215REDUCE_SHAPE_INFER(tosa::ReduceMaxOp)
3216REDUCE_SHAPE_INFER(tosa::ReduceMinOp)
3217REDUCE_SHAPE_INFER(tosa::ReduceProductOp)
3218REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
3219#undef REDUCE_SHAPE_INFER
3220COMPATIBLE_RETURN_TYPES(tosa::ConcatOp)
3221#undef COMPATIBLE_RETURN_TYPES
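// As an example of the shape inference generated above (hypothetical types):
// a reduction over tensor<2x3x4xf32> with axis = 1 infers tensor<2x1x4xf32>;
// if the operand is unranked (or the axis is out of range), only the f32
// element type is propagated and the result shape is left unknown.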
3222
3223template <typename T>
3224static LogicalResult verifyReduceOp(T op) {
3225 // All TOSA reduce Ops have input, output and axis.
3226 TensorType inputType = op.getInput().getType();
3227 TensorType outputType = op.getOutput().getType();
3228 int32_t reduceAxis = op.getAxis();
3229
3230 if (reduceAxis < 0) {
3231 op.emitOpError("reduce axis must not be negative");
3232 return failure();
3233 }
3234 if (inputType.hasRank()) {
3235 int64_t inputRank = inputType.getRank();
3236 // We allow for a special case where the input/output shape has rank 0 and
3237 // axis is also 0.
3238 if (reduceAxis >= inputRank && (reduceAxis != 0 || inputRank != 0)) {
3239 op.emitOpError("expect input tensor rank (")
3240 << inputRank << ") to be larger than reduce axis (" << reduceAxis
3241 << ")";
3242 return failure();
3243 }
3244 }
3245 if (outputType.hasRank()) {
3246 int64_t outputRank = outputType.getRank();
3247 if (inputType.hasRank() && outputRank != inputType.getRank()) {
3248 op.emitOpError(
3249 "expect output tensor rank to be equal to input tensor rank");
3250 return failure();
3251 }
3252 if (reduceAxis >= outputRank && (reduceAxis != 0 || outputRank != 0)) {
3253 op.emitOpError("expect output tensor rank (")
3254 << outputRank << ") to be larger than reduce axis (" << reduceAxis
3255 << ")";
3256 return failure();
3257 }
3258 // We can only verify the reduced dimension size to be 1 if this is not
3259 // the special case of output rank == 0.
3260 if (outputRank != 0) {
3261 auto outputShape = outputType.getShape();
3262 if (!outputType.isDynamicDim(reduceAxis) &&
3263 outputShape[reduceAxis] != 1) {
3264 op.emitOpError("expect reduced dimension size to be 1, got ")
3265 << outputShape[reduceAxis];
3266 return failure();
3267 }
3268 }
3269 }
3270 return success();
3271}
3272
3273LogicalResult tosa::ReduceAllOp::verify() { return verifyReduceOp(*this); }
3274LogicalResult tosa::ReduceAnyOp::verify() { return verifyReduceOp(*this); }
3275LogicalResult tosa::ReduceMaxOp::verify() { return verifyReduceOp(*this); }
3276LogicalResult tosa::ReduceMinOp::verify() { return verifyReduceOp(*this); }
3277LogicalResult tosa::ReduceProductOp::verify() { return verifyReduceOp(*this); }
3278LogicalResult tosa::ReduceSumOp::verify() { return verifyReduceOp(*this); }
3279
3280static LogicalResult NAryInferReturnTypes(
3281 const ValueShapeRange &operands,
3282 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3283 llvm::SmallVector<int64_t> outShape;
3284 if (resolveBroadcastShape(operands, outShape).failed()) {
3285 inferredReturnShapes.push_back(ShapedTypeComponents());
3286 } else {
3287 inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
3288 }
3289 return success();
3290}
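// Assuming resolveBroadcastShape implements the usual TOSA size-1
// broadcasting rules, an elementwise op on tensor<2x1x5xf32> and
// tensor<1x3x5xf32> infers tensor<2x3x5xf32>; irreconcilable operand shapes
// fall back to an unranked result as above.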
3291
3292#define NARY_SHAPE_INFER(OP) \
3293 LogicalResult OP::inferReturnTypeComponents( \
3294 MLIRContext *context, ::std::optional<Location> location, \
3295 ValueShapeRange operands, DictionaryAttr attributes, \
3296 OpaqueProperties properties, RegionRange regions, \
3297 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
3298 return NAryInferReturnTypes(operands, inferredReturnShapes); \
3299 }
3300
3301NARY_SHAPE_INFER(tosa::AbsOp)
3302NARY_SHAPE_INFER(tosa::AddOp)
3303NARY_SHAPE_INFER(tosa::ArithmeticRightShiftOp)
3304NARY_SHAPE_INFER(tosa::BitwiseAndOp)
3305NARY_SHAPE_INFER(tosa::BitwiseOrOp)
3306NARY_SHAPE_INFER(tosa::BitwiseXorOp)
3307NARY_SHAPE_INFER(tosa::BitwiseNotOp)
3308NARY_SHAPE_INFER(tosa::CastOp)
3309NARY_SHAPE_INFER(tosa::CeilOp)
3310NARY_SHAPE_INFER(tosa::ClampOp)
3311NARY_SHAPE_INFER(tosa::ClzOp)
3312NARY_SHAPE_INFER(tosa::CosOp)
3313NARY_SHAPE_INFER(tosa::ExpOp)
3314NARY_SHAPE_INFER(tosa::FloorOp)
3315NARY_SHAPE_INFER(tosa::GreaterEqualOp)
3316NARY_SHAPE_INFER(tosa::GreaterOp)
3317NARY_SHAPE_INFER(tosa::IdentityOp)
3318NARY_SHAPE_INFER(tosa::IntDivOp)
3319NARY_SHAPE_INFER(tosa::LogOp)
3320NARY_SHAPE_INFER(tosa::LogicalAndOp)
3321NARY_SHAPE_INFER(tosa::LogicalLeftShiftOp)
3322NARY_SHAPE_INFER(tosa::LogicalNotOp)
3323NARY_SHAPE_INFER(tosa::LogicalOrOp)
3324NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
3325NARY_SHAPE_INFER(tosa::LogicalXorOp)
3326NARY_SHAPE_INFER(tosa::MaximumOp)
3327NARY_SHAPE_INFER(tosa::MinimumOp)
3328NARY_SHAPE_INFER(tosa::PowOp)
3329NARY_SHAPE_INFER(tosa::ReciprocalOp)
3330NARY_SHAPE_INFER(tosa::ReverseOp)
3331NARY_SHAPE_INFER(tosa::RsqrtOp)
3332NARY_SHAPE_INFER(tosa::SinOp)
3333NARY_SHAPE_INFER(tosa::SelectOp)
3334NARY_SHAPE_INFER(tosa::SubOp)
3335NARY_SHAPE_INFER(tosa::TanhOp)
3336NARY_SHAPE_INFER(tosa::ErfOp)
3337NARY_SHAPE_INFER(tosa::SigmoidOp)
3338#undef NARY_SHAPE_INFER
3339
3340LogicalResult tosa::NegateOp::inferReturnTypeComponents(
3341 MLIRContext *context, ::std::optional<Location> location,
3342 NegateOp::Adaptor adaptor,
3343 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3344 ShapeAdaptor inputShape(adaptor.getInput1().getType());
3345 inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
3346 return success();
3347}
3348
3349LogicalResult tosa::NegateOp::verify() {
3350 // Verify same element type
3351 const Type input1Type = getInput1().getType();
3352 const Type outputType = getOutput().getType();
3353 if (verifySameElementTypes(*this, input1Type, outputType).failed())
3354 return failure();
3355
3356 // Verify same shape
3357 const SmallVector<Type, 2> types = {input1Type, outputType};
3358 if (failed(verifyCompatibleShapes(types)))
3359 return emitOpError() << "requires the same shape for input1 and output";
3360
3361 const Type input1EType = getStorageElementTypeOrSelf(getInput1().getType());
3362 const Type input1ZpEType =
3363 getStorageElementTypeOrSelf(getInput1Zp().getType());
3364 if (input1EType != input1ZpEType) {
3365 return emitOpError("expect both input1 and its zero point are the same "
3366 "element type, got ")
3367 << input1EType << " and " << input1ZpEType;
3368 }
3369 const Type outputEType = getStorageElementTypeOrSelf(getOutput().getType());
3370 const Type outputZpEType =
3371 getStorageElementTypeOrSelf(getOutputZp().getType());
3372 if (outputEType != outputZpEType) {
3373 return emitOpError("expect both output and its zero point are the same "
3374 "element type, got ")
3375 << outputEType << " and " << outputZpEType;
3376 }
3377
3378 FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
3379 if (succeeded(maybeIZp) && verifyInput1ZeroPoint(*maybeIZp).failed())
3380 return failure();
3381
3382 FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
3383 if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
3384 return failure();
3385
3386 return success();
3387}
3388
3389static LogicalResult poolingInferReturnTypes(
3390 ShapeAdaptor inputShape, ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride,
3391 ArrayRef<int64_t> pad,
3392 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3393 llvm::SmallVector<int64_t> outputShape;
3394 outputShape.resize(4, ShapedType::kDynamic);
3395
3396 // We only know the rank if the input type is unranked.
3397 if (!inputShape) {
3398 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3399 return success();
3400 }
3401
3402 // Batch and number of channels are identical for a pooling layer.
3403 outputShape[0] = inputShape.getDimSize(0);
3404 outputShape[3] = inputShape.getDimSize(3);
3405
3406 int64_t height = inputShape.getDimSize(1);
3407 int64_t width = inputShape.getDimSize(2);
3408
3409 if (ShapedType::isStatic(height)) {
3410 int64_t padded = height + pad[0] + pad[1] - kernel[0];
3411 outputShape[1] = padded / stride[0] + 1;
3412 }
3413
3414 if (ShapedType::isStatic(width)) {
3415 int64_t padded = width + pad[2] + pad[3] - kernel[1];
3416 outputShape[2] = padded / stride[1] + 1;
3417 }
3418
3419 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3420 return success();
3421}
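// A worked example of the pooling output size (hypothetical shapes): for an
// input of shape [1, 32, 32, 8] with kernel = {3, 3}, stride = {2, 2} and
// pad = {1, 1, 1, 1}, the padded extent is 32 + 1 + 1 - 3 = 31, so
// OH = OW = 31 / 2 + 1 = 16 and the inferred output shape is [1, 16, 16, 8].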
3422
3423LogicalResult Conv2DOp::inferReturnTypeComponents(
3424 MLIRContext *context, ::std::optional<Location> location,
3425 Conv2DOp::Adaptor adaptor,
3426 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3427 llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
3428
3429 int64_t inputWidth = ShapedType::kDynamic;
3430 int64_t inputHeight = ShapedType::kDynamic;
3431 int64_t weightWidth = ShapedType::kDynamic;
3432 int64_t weightHeight = ShapedType::kDynamic;
3433
3434 // Input shape describes input width/height and batch.
3435
3436 ShapeAdaptor inputShape(adaptor.getInput().getType());
3437 if (inputShape.hasRank()) {
3438 outputShape[0] = inputShape.getDimSize(0);
3439 inputHeight = inputShape.getDimSize(1);
3440 inputWidth = inputShape.getDimSize(2);
3441 }
3442
3443 // Weight shape describes the filter width/height and the output channels.
3444 ShapeAdaptor weightShape(adaptor.getWeight().getType());
3445 if (weightShape.hasRank()) {
3446 outputShape[3] = weightShape.getDimSize(0);
3447 weightHeight = weightShape.getDimSize(1);
3448 weightWidth = weightShape.getDimSize(2);
3449 }
3450
3451 // Bias shape can describe the output channels.
3452 ShapeAdaptor biasShape(adaptor.getBias().getType());
3453 if (biasShape.hasRank()) {
3454 outputShape[3] = ShapedType::isDynamic(outputShape[3])
3455 ? biasShape.getDimSize(0)
3456 : outputShape[3];
3457 }
3458
3459 llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
3460 llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3461 llvm::ArrayRef<int64_t> padding = adaptor.getPad();
3462
3463 if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
3464 int64_t inputSize = inputHeight + padding[0] + padding[1];
3465 int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
3466 int64_t unstridedResult = inputSize - filterSize + 1;
3467 outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
3468 }
3469
3470 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
3471 int64_t inputSize = inputWidth + padding[2] + padding[3];
3472 int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
3473 int64_t unstridedResult = inputSize - filterSize + 1;
3474 outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
3475 }
3476
3477 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3478 return success();
3479}
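// A worked example of the convolution arithmetic above (hypothetical shapes):
// IH = 16, pad top/bottom = 1/1, KH = 3, dilation_y = 2 and stride_y = 2 give
// inputSize = 16 + 1 + 1 = 18, filterSize = (3 - 1) * 2 + 1 = 5,
// unstridedResult = 18 - 5 + 1 = 14 and OH = (14 - 1) / 2 + 1 = 7. The output
// channel count comes from weight dim 0, with bias dim 0 as a fallback.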
3480
3481LogicalResult Conv2DOp::verify() {
3482 if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
3483 verifyConvOpErrorIf(*this).failed())
3484 return failure();
3485 return success();
3486}
3487
3488LogicalResult Conv3DOp::inferReturnTypeComponents(
3489 MLIRContext *context, ::std::optional<Location> location,
3490 Conv3DOp::Adaptor adaptor,
3491 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3492 llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamic);
3493
3494 int64_t inputWidth = ShapedType::kDynamic;
3495 int64_t inputHeight = ShapedType::kDynamic;
3496 int64_t inputDepth = ShapedType::kDynamic;
3497
3498 int64_t weightWidth = ShapedType::kDynamic;
3499 int64_t weightHeight = ShapedType::kDynamic;
3500 int64_t weightDepth = ShapedType::kDynamic;
3501
3502 // Input shape describes input width/height and batch.
3503 ShapeAdaptor inputShape(adaptor.getInput().getType());
3504 if (inputShape.hasRank()) {
3505 outputShape[0] = inputShape.getDimSize(0);
3506 inputDepth = inputShape.getDimSize(1);
3507 inputHeight = inputShape.getDimSize(2);
3508 inputWidth = inputShape.getDimSize(3);
3509 }
3510
3511 // Weight shape describes the filter width/height and the output channels.
3512 ShapeAdaptor weightShape(adaptor.getWeight().getType());
3513 if (weightShape.hasRank()) {
3514 outputShape[4] = weightShape.getDimSize(0);
3515 weightDepth = weightShape.getDimSize(1);
3516 weightHeight = weightShape.getDimSize(2);
3517 weightWidth = weightShape.getDimSize(3);
3518 }
3519
3520 // Bias shape can describe the output channels.
3521 ShapeAdaptor biasShape(adaptor.getBias().getType());
3522 if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) {
3523 outputShape[4] = biasShape.getDimSize(0);
3524 }
3525
3526 llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
3527 llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3528 llvm::ArrayRef<int64_t> pad = adaptor.getPad();
3529
3530 if (ShapedType::isStatic(inputDepth) && ShapedType::isStatic(weightDepth)) {
3531 int32_t inputSize = inputDepth + pad[0] + pad[1];
3532 int32_t filterSize = (weightDepth - 1) * dilation[0] + 1;
3533 int32_t unstridedResult = inputSize - filterSize + 1;
3534 outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
3535 }
3536
3537 if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
3538 int32_t inputSize = inputHeight + pad[2] + pad[3];
3539 int32_t filterSize = (weightHeight - 1) * dilation[1] + 1;
3540 int32_t unstridedResult = inputSize - filterSize + 1;
3541 outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
3542 }
3543
3544 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
3545 int32_t inputSize = inputWidth + pad[4] + pad[5];
3546 int32_t filterSize = (weightWidth - 1) * dilation[2] + 1;
3547 int32_t unstridedResult = inputSize - filterSize + 1;
3548 outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
3549 }
3550
3551 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3552 return success();
3553}
3554
3555LogicalResult Conv3DOp::verify() {
3556 if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
3557 verifyConvOpErrorIf(*this).failed())
3558 return failure();
3559 return success();
3560}
3561
3562LogicalResult AvgPool2dOp::inferReturnTypeComponents(
3563 MLIRContext *context, ::std::optional<Location> location,
3564 AvgPool2dOp::Adaptor adaptor,
3565 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3566 ShapeAdaptor inputShape(adaptor.getInput().getType());
3567 const Properties &prop = adaptor.getProperties();
3568 return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
3569 inferredReturnShapes);
3570}
3571
3572LogicalResult MaxPool2dOp::inferReturnTypeComponents(
3573 MLIRContext *context, ::std::optional<Location> location,
3574 MaxPool2dOp::Adaptor adaptor,
3575 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3576 ShapeAdaptor inputShape(adaptor.getInput().getType());
3577 const Properties &prop = adaptor.getProperties();
3578 return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
3579 inferredReturnShapes);
3580}
3581
3582LogicalResult MaxPool2dOp::verify() {
3583 if (failed(verifySameElementTypes(*this, /* intype = */ getInput().getType(),
3584 /* outType = */ getOutput().getType())))
3585 return failure();
3586
3587 if (failed(verifyPoolingOp(*this)))
3588 return failure();
3589
3590 return success();
3591}
3592
3593LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
3594 MLIRContext *context, ::std::optional<Location> location,
3595 DepthwiseConv2DOp::Adaptor adaptor,
3596 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3597 llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
3598
3599 int64_t inputWidth = ShapedType::kDynamic;
3600 int64_t inputHeight = ShapedType::kDynamic;
3601 int64_t inputChannels = ShapedType::kDynamic;
3602
3603 int64_t weightWidth = ShapedType::kDynamic;
3604 int64_t weightHeight = ShapedType::kDynamic;
3605 int64_t depthChannels = ShapedType::kDynamic;
3606
3607 // Input shape describes input width/height and batch.
3608 ShapeAdaptor inputShape(adaptor.getInput().getType());
3609 if (inputShape.hasRank()) {
3610 outputShape[0] = inputShape.getDimSize(0);
3611 inputHeight = inputShape.getDimSize(1);
3612 inputWidth = inputShape.getDimSize(2);
3613 inputChannels = inputShape.getDimSize(3);
3614 }
3615
3616 // Weight shape describes the filter width/height and the depth multiplier.
3617 ShapeAdaptor weightShape(adaptor.getWeight().getType());
3618 if (weightShape.hasRank()) {
3619 weightHeight = weightShape.getDimSize(0);
3620 weightWidth = weightShape.getDimSize(1);
3621 inputChannels = ShapedType::isDynamic(inputChannels)
3622 ? weightShape.getDimSize(2)
3623 : inputChannels;
3624 depthChannels = weightShape.getDimSize(3);
3625 }
3626
3627 // If both inputChannels and depthChannels are available we can determine
3628 // the output channels.
3629 if (ShapedType::isStatic(inputChannels) &&
3630 ShapedType::isStatic(depthChannels)) {
3631 outputShape[3] = inputChannels * depthChannels;
3632 }
3633
3634 // Bias shape can describe the output channels.
3635 ShapeAdaptor biasShape(adaptor.getBias().getType());
3636 if (biasShape.hasRank()) {
3637 outputShape[3] = ShapedType::isDynamic(outputShape[3])
3638 ? biasShape.getDimSize(0)
3639 : outputShape[3];
3640 }
3641
3642 llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
3643 llvm::ArrayRef<int64_t> padding = adaptor.getPad();
3644 llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3645
3646 if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
3647 int64_t inputSize = inputHeight + padding[0] + padding[1];
3648 int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
3649 int64_t unstridedResult = inputSize - filterSize + 1;
3650 outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
3651 }
3652
3653 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
3654 int64_t inputSize = inputWidth + padding[2] + padding[3];
3655 int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
3656 int64_t unstridedResult = inputSize - filterSize + 1;
3657 outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
3658 }
3659
3660 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3661 return success();
3662}
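// A worked example (hypothetical shapes): an input of [1, 16, 16, 8] with a
// weight of [3, 3, 8, 2] yields 8 * 2 = 16 output channels; with
// pad = {1, 1, 1, 1}, stride = {1, 1} and dilation = {1, 1},
// unstridedResult = 18 - 3 + 1 = 16 and OH = OW = (16 - 1) / 1 + 1 = 16, so
// the inferred output shape is [1, 16, 16, 16].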
3663
3664LogicalResult DepthwiseConv2DOp::verify() {
3665 if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
3666 verifyConvOpErrorIf(*this).failed())
3667 return failure();
3668 return success();
3669}
3670
3671LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
3672 MLIRContext *context, ::std::optional<Location> location,
3673 TransposeConv2DOp::Adaptor adaptor,
3674 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3675 llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
3676
3677 int64_t inputWidth = ShapedType::kDynamic;
3678 int64_t inputHeight = ShapedType::kDynamic;
3679 int64_t weightWidth = ShapedType::kDynamic;
3680 int64_t weightHeight = ShapedType::kDynamic;
3681
3682 // Input shape describes input width/height and batch.
3683 ShapeAdaptor inputShape(adaptor.getInput().getType());
3684 if (inputShape.hasRank()) {
3685 outputShape[0] = ShapedType::isDynamic(outputShape[0])
3686 ? inputShape.getDimSize(0)
3687 : outputShape[0];
3688 inputHeight = inputShape.getDimSize(1);
3689 inputWidth = inputShape.getDimSize(2);
3690 }
3691
3692 // Weight shape describes the filter width/height and the output channels.
3693 ShapeAdaptor weightShape(adaptor.getWeight().getType());
3694 if (weightShape.hasRank()) {
3695 outputShape[3] = ShapedType::isDynamic(outputShape[3])
3696 ? weightShape.getDimSize(0)
3697 : outputShape[3];
3698 weightHeight = weightShape.getDimSize(1);
3699 weightWidth = weightShape.getDimSize(2);
3700 }
3701
3702 // Bias shape can describe the output channels.
3703 ShapeAdaptor biasShape(adaptor.getBias().getType());
3704 if (biasShape.hasRank()) {
3705 outputShape[3] = ShapedType::isDynamic(outputShape[3])
3706 ? biasShape.getDimSize(0)
3707 : outputShape[3];
3708 }
3709
3710 llvm::ArrayRef<int64_t> padding = adaptor.getOutPad();
3711 llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3712
3713 if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
3714 int64_t calculateSize =
3715 (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
3716 outputShape[1] =
3717 ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
3718 }
3719
3720 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
3721 int64_t calculateSize =
3722 (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
3723 outputShape[2] =
3724 ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];
3725 }
3726
3727 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3728 return success();
3729}
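// A worked example of the transpose-convolution size above (hypothetical
// shapes): IH = 4, stride_y = 2, out_pad_top = out_pad_bottom = 0 and KH = 3
// give OH = (4 - 1) * 2 + 0 + 0 + 3 = 9; the width follows the same formula
// with stride_x, out_pad_left/right and KW.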
3730
3731LogicalResult TransposeConv2DOp::verify() {
3732 if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
3733 return failure();
3734
3735 const llvm::ArrayRef<int64_t> strides = getStride();
3736 const int64_t strideY = strides[0];
3737 const int64_t strideX = strides[1];
3738
3739 if (strideY < 1 || strideX < 1)
3740 return emitOpError("expect all stride values to be >= 1, got [")
3741 << strides << "]";
3742
3743 const auto checkPadAgainstKernelDim =
3744 [this](int64_t pad_value, int64_t kernel_dim_size,
3745 llvm::StringRef pad_name,
3746 llvm::StringRef kernel_dim_name) -> LogicalResult {
3747 if (pad_value <= -kernel_dim_size)
3748 return emitOpError("expected ")
3749 << pad_name << " > -" << kernel_dim_name
3750 << ", but got: " << pad_name << "=" << pad_value << " and "
3751 << kernel_dim_name << "=" << kernel_dim_size;
3752 return success();
3753 };
3754
3755 const llvm::ArrayRef<int64_t> padding = getOutPad();
3756 const int64_t outPadTop = padding[0];
3757 const int64_t outPadBottom = padding[1];
3758 const int64_t outPadLeft = padding[2];
3759 const int64_t outPadRight = padding[3];
3760
3761 const auto weightType =
3762 llvm::dyn_cast<RankedTensorType>(getWeight().getType());
3763
3764 if (weightType) {
3765 const int64_t kernelHeight = weightType.getDimSize(1);
3766 if (ShapedType::isStatic(kernelHeight)) {
3767 if (failed(checkPadAgainstKernelDim(outPadTop, kernelHeight,
3768 "out_pad_top", "KH")))
3769 return failure();
3770
3771 if (failed(checkPadAgainstKernelDim(outPadBottom, kernelHeight,
3772 "out_pad_bottom", "KH")))
3773 return failure();
3774 }
3775
3776 const int64_t kernelWidth = weightType.getDimSize(2);
3777 if (ShapedType::isStatic(kernelWidth)) {
3778 if (failed(checkPadAgainstKernelDim(outPadLeft, kernelWidth,
3779 "out_pad_left", "KW")))
3780 return failure();
3781
3782 if (failed(checkPadAgainstKernelDim(outPadRight, kernelWidth,
3783 "out_pad_right", "KW")))
3784 return failure();
3785 }
3786 }
3787
3788 // The rest of the checks depend on the output type being a RankedTensorType.
3789 const auto outputType =
3790 llvm::dyn_cast<RankedTensorType>(getOutput().getType());
3791 if (!outputType)
3792 return success();
3793
3794 const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
3795 if (inputType && weightType) {
3796 const int64_t inputHeight = inputType.getDimSize(1);
3797 const int64_t kernelHeight = weightType.getDimSize(1);
3798 const int64_t outputHeight = outputType.getDimSize(1);
3799
3800 if (ShapedType::isStatic(inputHeight) &&
3801 ShapedType::isStatic(outputHeight)) {
3802 if (outputHeight !=
3803 (inputHeight - 1) * strideY + outPadTop + outPadBottom + kernelHeight)
3804 return emitOpError(
3805 "dimension mismatch: expected OH == (IH - 1) * stride_y "
3806 "+ out_pad_top + out_pad_bottom + KH, but got ")
3807 << outputHeight << " != (" << inputHeight << " - 1) * "
3808 << strideY << " + " << outPadTop << " + " << outPadBottom
3809 << " + " << kernelHeight;
3810 }
3811
3812 const int64_t inputWidth = inputType.getDimSize(2);
3813 const int64_t kernelWidth = weightType.getDimSize(2);
3814 const int64_t outputWidth = outputType.getDimSize(2);
3815
3816 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(outputWidth)) {
3817 if (outputWidth !=
3818 (inputWidth - 1) * strideX + outPadLeft + outPadRight + kernelWidth)
3819 return emitOpError(
3820 "dimension mismatch: expected OW == (IW - 1) * stride_x "
3821 "+ out_pad_left + out_pad_right + KW, but got ")
3822 << outputWidth << " != (" << inputWidth << " - 1) * " << strideX
3823 << " + " << outPadLeft << " + " << outPadRight << " + "
3824 << kernelWidth;
3825 }
3826 }
3827
3828 const auto biasType = llvm::dyn_cast<RankedTensorType>(getBias().getType());
3829
3830 if (!biasType)
3831 return success();
3832
3833 const int64_t biasChannels = biasType.getDimSize(0);
3834
3835 // Skip further checks if bias is dynamic
3836 if (biasChannels == ShapedType::kDynamic)
3837 return success();
3838
3839 const int64_t outputChannels = outputType.getDimSize(3);
3840 if (!ShapedType::isDynamic(outputChannels) &&
3841 biasChannels != outputChannels && biasChannels != 1)
3842 return emitOpError(
3843 "bias channels expected to be equal to output channels (")
3844 << outputChannels << ") or 1, got " << biasChannels;
3845
3846 return success();
3847}
3848
3849LogicalResult RescaleOp::verify() {
3850 auto inputType = llvm::dyn_cast<ShapedType>(getInput().getType());
3851 if (!inputType) {
3852 emitOpError("expect shaped tensor for input, got ") << getInput().getType();
3853 return failure();
3854 }
3855
3856 auto inputElementType =
3857 getStorageElementTypeOrSelf(inputType.getElementType());
3858 if (!mlir::isa<IntegerType>(inputElementType)) {
3859 emitOpError("expect input to have integer element type, got ")
3860 << inputElementType;
3861 return failure();
3862 }
3863
3864 auto outputType = llvm::dyn_cast<ShapedType>(getOutput().getType());
3865 if (!outputType) {
3866 emitOpError("expect shaped tensor for output, got ")
3867 << getOutput().getType();
3868 return failure();
3869 }
3870
3871 auto outputElementType =
3872 getStorageElementTypeOrSelf(outputType.getElementType());
3873 if (!mlir::isa<IntegerType>(outputElementType)) {
3874 emitOpError("expect output to have integer element type, got ")
3875 << outputElementType;
3876 return failure();
3877 }
3878
3879 if (verifyRescaleValueAndZpTypes(*this, getInput(), getInputZp(), "input")
3880 .failed())
3881 return failure();
3882
3883 if (verifyRescaleValueAndZpTypes(*this, getOutput(), getOutputZp(), "output")
3884 .failed())
3885 return failure();
3886
3887 FailureOr<int64_t> maybeIZp = getInputZeroPoint();
3888 if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
3889 return failure();
3890
3891 FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
3892 if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
3893 return failure();
3894
3895 auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().getType());
3896 if (!multiplierType) {
3897 emitOpError("expect shaped tensor for multiplier, got ")
3898 << getMultiplier().getType();
3899 return failure();
3900 }
3901
3902 auto shiftType = llvm::dyn_cast<ShapedType>(getShift().getType());
3903 if (!shiftType) {
3904 emitOpError("expect shaped tensor for shift, got ") << getShift().getType();
3905 return failure();
3906 }
3907
3908 // multiplier element type must be i32 for scale32 = true
3909 if (getScale32() && !multiplierType.getElementType().isInteger(32)) {
3910 emitOpError("expect i32 element type for multiplier for scale32=true, got ")
3911 << multiplierType.getElementType();
3912 return failure();
3913 }
3914
3915 // multiplier element type must be i16 for scale32 = false
3916 if (!getScale32() && !multiplierType.getElementType().isInteger(16)) {
3917 emitOpError(
3918 "expect i16 element type for multiplier for scale32=false, got ")
3919 << multiplierType.getElementType();
3920 return failure();
3921 }
3922
3923 if (!inputType.hasRank())
3924 return success();
3925
3926 // multiplier/shift must have shape = {numChannels},
3927 // where numChannels is 1 if per_channel = false,
3928 // otherwise numChannels is the size of the input shape's last dimension.
3929 int64_t numChannels = 1;
3930 if (getPerChannel()) {
3931 if (inputType.getRank() < 1) {
3932 emitOpError("requires input to be at least rank 1 when per_channel is "
3933 "true, but got rank ")
3934 << inputType.getRank();
3935 return failure();
3936 }
3937 numChannels = inputType.getDimSize(inputType.getRank() - 1);
3938 }
3939
3940 if (!multiplierType.hasRank())
3941 return success();
3942
3943 ArrayRef<int64_t> multiplierShape = multiplierType.getShape();
3944 // multiplier input has rank 1 by dialect definition
3945 if (multiplierShape[0] != ShapedType::kDynamic &&
3946 multiplierShape[0] != numChannels) {
3947 emitOpError("expect shape of { ")
3948 << numChannels << " } for multiplier input, got { "
3949 << multiplierShape[0] << " }";
3950 return failure();
3951 }
3952
3953 if (!shiftType.hasRank())
3954 return success();
3955
3956 ArrayRef<int64_t> shiftShape = shiftType.getShape();
3957 // shift input has rank 1 by dialect definition
3958 if (shiftShape[0] != ShapedType::kDynamic && shiftShape[0] != numChannels) {
3959 emitOpError("expect shape of { ")
3960 << numChannels << " } for shift input, got { " << shiftShape[0] << " }";
3961 return failure();
3962 }
3963
3964 return success();
3965}
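// Examples of the shape rules above (hypothetical types): with
// per_channel = false, multiplier and shift must both have shape {1}; with
// per_channel = true and an input of tensor<1x10x10x16xi8>, they must have
// shape {16}. The multiplier element type must be i32 when scale32 = true and
// i16 otherwise.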
3966
3967LogicalResult RescaleOp::inferReturnTypeComponents(
3968 MLIRContext *context, ::std::optional<Location> location,
3969 RescaleOp::Adaptor adaptor,
3970 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3971 ShapeAdaptor inputShape(adaptor.getInput().getType());
3972 inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
3973 return success();
3974}
3975
3976LogicalResult CastFromBlockScaledOp::inferReturnTypeComponents(
3977 MLIRContext *context, ::std::optional<Location> location,
3978 CastFromBlockScaledOp::Adaptor adaptor,
3979 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3980 const ShapeAdaptor inputShape(adaptor.getInputData().getType());
3981 inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
3982 return success();
3983}
3984
3985LogicalResult CastFromBlockScaledOp::verify() {
3986 const Type inputDataType = getInputData().getType();
3987 const Type outputDataType = getResult().getType();
3988 if (failed(verifyCompatibleShape(inputDataType, outputDataType)))
3989 return emitOpError() << "require compatible shapes for input_data ("
3990 << inputDataType << ") and "
3991 << "output_data (" << outputDataType << ")";
3992
3993 const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
3994
3995 if (inputDataShape.hasRank()) {
3996 const unsigned int blockSize =
3997 BlockSizeAttr::getBlockSizeValue(getBlockSize());
3998 const int64_t inputDataLastDim =
3999 inputDataShape.getDimSize(inputDataShape.getRank() - 1);
4000 if (inputDataLastDim % blockSize != 0)
4001 return emitOpError() << "expect last dimension of input_data ("
4002 << inputDataLastDim
4003 << ") to be divisible by block_size (" << blockSize
4004 << ")";
4005
4006 const Type inputScaleType = getInputScale().getType();
4007 const ShapeAdaptor inputScaleShape = ShapeAdaptor(inputScaleType);
4008
4009 if (inputScaleShape.hasRank()) {
4010 SmallVector<int64_t> inputDataDims, inputScaleDims;
4011 inputDataShape.getDims(inputDataDims);
4012 inputScaleShape.getDims(inputScaleDims);
4013
4014 if (inputDataDims.size() != inputScaleDims.size() ||
4015 failed(verifyCompatibleShape(
4016 ArrayRef<int64_t>(inputDataDims).drop_back(1),
4017 ArrayRef<int64_t>(inputScaleDims).drop_back(1))))
4018 return emitOpError() << "require compatible shapes for input_data ("
4019 << inputDataType << ") and "
4020 << "input_scale (" << inputScaleType
4021 << ") except for the last dimension";
4022
4023 const SmallVector<int64_t, 2> dimsToCheck{inputDataLastDim / blockSize,
4024 inputScaleDims.back()};
4025 if (ShapedType::isStatic(inputDataLastDim) &&
4026 failed(verifyCompatibleDims(dimsToCheck)))
4027 return emitOpError()
4028 << "expect last dimension of input_scale ("
4029 << inputScaleDims.back()
4030 << ") to be equal to last dimension of input_data / block_size ("
4031 << inputDataDims.back() / blockSize << ")";
4032 }
4033 }
4034
4035 return success();
4036}
4037
4038LogicalResult CastToBlockScaledOp::inferReturnTypeComponents(
4039 MLIRContext *context, ::std::optional<Location> location,
4040 CastToBlockScaledOp::Adaptor adaptor,
4041 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4042 const ShapeAdaptor inputShape(adaptor.getInputData().getType());
4043 inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
4044 if (!inputShape.hasRank())
4045 return success();
4046
4047 // Calculate output_scale shape if ranked input provided
4048 SmallVector<int64_t> outputScaleShape;
4049 inputShape.getDims(outputScaleShape);
4050 const int64_t lastDimLoc = inputShape.getRank() - 1;
4051 const int64_t lastDimSize = inputShape.getDimSize(lastDimLoc);
4052 if (ShapedType::isStatic(lastDimSize)) {
4053 const unsigned int blockSize =
4054 BlockSizeAttr::getBlockSizeValue(adaptor.getBlockSize());
4055 outputScaleShape[lastDimLoc] = lastDimSize / blockSize;
4056 }
4057 inferredReturnShapes.push_back(ShapedTypeComponents(outputScaleShape));
4058 return success();
4059}
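// For example (hypothetical shapes), an input_data of tensor<4x64xf32> with
// block_size = 32 infers an output_data shape of [4, 64] and an output_scale
// shape of [4, 64 / 32] = [4, 2].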
4060
4061LogicalResult CastToBlockScaledOp::verify() {
4062 const Type inputDataType = getInputData().getType();
4063 const Type outputDataType = getResult(0).getType();
4064 if (failed(verifyCompatibleShape(inputDataType, outputDataType)))
4065 return emitOpError() << "require compatible shapes for input_data ("
4066 << inputDataType << ") and "
4067 << "output_data (" << outputDataType << ")";
4068
4069 const unsigned int blockSize =
4070 BlockSizeAttr::getBlockSizeValue(getBlockSize());
4071 const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
4072 if (inputDataShape.hasRank()) {
4073 const int64_t inputDataLastDim =
4074 inputDataShape.getDimSize(inputDataShape.getRank() - 1);
4075 if (ShapedType::isStatic(inputDataLastDim) &&
4076 inputDataLastDim % blockSize != 0)
4077 return emitOpError() << "expect last dimension of input_data ("
4078 << inputDataLastDim
4079 << ") to be divisible by block_size (" << blockSize
4080 << ")";
4081 }
4082
4083 const ShapeAdaptor outputDataShape = ShapeAdaptor(outputDataType);
4084 const Type outputScaleType = getResult(1).getType();
4085 const ShapeAdaptor outputScaleShape = ShapeAdaptor(outputScaleType);
4086 if (outputDataShape.hasRank() && outputScaleShape.hasRank()) {
4087 SmallVector<int64_t> outputDataDims, outputScaleDims;
4088 outputDataShape.getDims(outputDataDims);
4089 outputScaleShape.getDims(outputScaleDims);
4090
4091 if (outputDataDims.size() != outputScaleDims.size() ||
4092 failed(verifyCompatibleShape(
4093 ArrayRef<int64_t>(outputDataDims).drop_back(1),
4094 ArrayRef<int64_t>(outputScaleDims).drop_back(1))))
4095 return emitOpError() << "require compatible shapes for output_data ("
4096 << outputDataType << ") and "
4097 << "output_scale (" << outputScaleType
4098 << ") except for the last dimension";
4099
4100 const int64_t outputDataLastDim = outputDataDims.back();
4101 const SmallVector<int64_t, 2> dimsToCheck{outputDataLastDim / blockSize,
4102 outputScaleDims.back()};
4103 if (ShapedType::isStatic(outputDataLastDim) &&
4104 failed(verifyCompatibleDims(dimsToCheck)))
4105 return emitOpError()
4106 << "expect last dimension of output_scale ("
4107 << outputScaleDims.back()
4108 << ") to be equal to last dimension of output_data / block_size ("
4109 << outputDataDims.back() / blockSize << ")";
4110 }
4111
4112 return success();
4113}
4114
4115LogicalResult IfOp::inferReturnTypeComponents(
4116 MLIRContext *context, ::std::optional<Location> location,
4117 IfOp::Adaptor adaptor,
4118 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4119 llvm::SmallVector<tosa::YieldOp> yieldOps;
4120 for (Region *region : adaptor.getRegions()) {
4121 for (auto &block : *region)
4122 if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
4123 yieldOps.push_back(returnOp);
4124 }
4125
4126 if (yieldOps.empty())
4127 return failure();
4128
4129 // Get the initial type information for the yield op.
4130 llvm::SmallVector<ValueKnowledge> resultKnowledge;
4131 resultKnowledge.reserve(yieldOps.front().getNumOperands());
4132 for (auto operand : yieldOps.front().getOperands()) {
4133 resultKnowledge.push_back(
4134 ValueKnowledge::getKnowledgeFromType(operand.getType()));
4135 }
4136
4137 for (auto yieldOp : yieldOps) {
4138 if (resultKnowledge.size() != yieldOp.getNumOperands())
4139 return failure();
4140
4141 for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
4142 int32_t index = it.index();
4143 auto meet = ValueKnowledge::meet(
4144 resultKnowledge[index],
4145 ValueKnowledge::getKnowledgeFromType(it.value().getType()));
4146 if (!meet)
4147 continue;
4148 resultKnowledge[index] = meet;
4149 }
4150 }
4151
4152 for (const ValueKnowledge &result : resultKnowledge) {
4153 inferredReturnShapes.push_back(result.getShapedTypeComponents());
4154 }
4155
4156 return success();
4157}
4158
4159LogicalResult WhileOp::inferReturnTypeComponents(
4160 MLIRContext *context, ::std::optional<Location> location,
4161 WhileOp::Adaptor adaptor,
4162 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4163 llvm::SmallVector<tosa::YieldOp> yieldOps;
4164 for (auto &block : adaptor.getBodyGraph())
4165 if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
4166 yieldOps.push_back(returnOp);
4167
4168 // TOSA's while must have a tosa.yield as its terminator. If none is found,
4169 // this tosa.while is invalid.
4170 if (yieldOps.empty())
4171 return failure();
4172
4173 // Get the initial type information from the operand types.
4174 llvm::SmallVector<ValueKnowledge> resultKnowledge;
4175 resultKnowledge.reserve(yieldOps.front().getNumOperands());
4176 for (auto operand : yieldOps.front().getOperands()) {
4177 resultKnowledge.push_back(
4178 ValueKnowledge::getKnowledgeFromType(operand.getType()));
4179 }
4180
4181 for (auto yieldOp : yieldOps) {
4182 if (resultKnowledge.size() != yieldOp.getNumOperands())
4183 return failure();
4184
4185 for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
4186 int32_t index = it.index();
4187 if (auto meet = ValueKnowledge::meet(
4188 resultKnowledge[index],
4189 ValueKnowledge::getKnowledgeFromType(it.value().getType()))) {
4190 resultKnowledge[index] = meet;
4191 }
4192 }
4193 }
4194
4195 for (const ValueKnowledge &result : resultKnowledge) {
4196 inferredReturnShapes.push_back(result.getShapedTypeComponents());
4197 }
4198
4199 return success();
4200}
4201
4202std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
4203 if (auto vt = llvm::dyn_cast<VectorType>(getType()))
4204 return llvm::to_vector<4>(vt.getShape());
4205 return std::nullopt;
4206}
4207
4208static void printInitializationList(OpAsmPrinter &parser,
4209 Block::BlockArgListType blocksArgs,
4210 ValueRange initializers,
4211 StringRef prefix = "") {
4212 assert(blocksArgs.size() == initializers.size() &&
4213 "expected same length of arguments and initializers");
4214 if (initializers.empty())
4215 return;
4216
4217 parser << prefix << '(';
4218 llvm::interleaveComma(
4219 llvm::zip(blocksArgs, initializers), parser,
4220 [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
4221 parser << ")";
4222}
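// Given block arguments (%arg0, %arg1) initialized from (%a, %b), the helper
// above prints the prefix followed by "(%arg0 = %a, %arg1 = %b)", and prints
// nothing at all when the initializer list is empty; the same helper is
// reused by both IfOp::print and WhileOp::print below.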
4223
4224// The parse and print methods of IfOp follow the implementation of the SCF dialect.
4225ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
4226 // Create the regions for 'then'.
4227 result.regions.reserve(2);
4228 Region *thenRegion = result.addRegion();
4229 Region *elseRegion = result.addRegion();
4230
4231 OpAsmParser::UnresolvedOperand cond;
4232
4233 if (parser.parseOperand(cond))
4234 return failure();
4235
4236 SmallVector<OpAsmParser::Argument, 4> regionArgs;
4237 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
4238
4239 // Parse the optional block arguments
4240 OptionalParseResult listResult =
4241 parser.parseOptionalAssignmentList(regionArgs, operands);
4242 if (listResult.has_value() && failed(listResult.value()))
4243 return failure();
4244
4245 // Parse a colon.
4246 if (failed(parser.parseColon()))
4247 return parser.emitError(parser.getCurrentLocation(),
4248 "expected type for condition operand");
4249
4250 // Parse the type of the condition operand
4251 Type condType;
4252 if (failed(parser.parseType(condType)))
4253 return parser.emitError(parser.getCurrentLocation(),
4254 "expected type for condition operand");
4255
4256 // Resolve operand with provided type
4257 if (failed(parser.resolveOperand(cond, condType, result.operands)))
4258 return failure();
4259
4260 // Parse optional block arg types
4261 if (listResult.has_value()) {
4262 FunctionType functionType;
4263
4264 if (failed(parser.parseType(functionType)))
4265 return parser.emitError(parser.getCurrentLocation())
4266 << "expected list of types for block arguments "
4267 << "followed by arrow type and list of return types";
4268
4269 result.addTypes(functionType.getResults());
4270
4271 if (functionType.getNumInputs() != operands.size()) {
4272 return parser.emitError(parser.getCurrentLocation())
4273 << "expected as many input types as operands "
4274 << "(expected " << operands.size() << " got "
4275 << functionType.getNumInputs() << ")";
4276 }
4277
4278 // Resolve input operands.
4279 if (failed(parser.resolveOperands(operands, functionType.getInputs(),
4280 parser.getCurrentLocation(),
4281 result.operands)))
4282 return failure();
4283 } else {
4284 // Parse optional results type list.
4285 if (parser.parseOptionalArrowTypeList(result.types))
4286 return failure();
4287 }
4288
4289 // Parse the 'then' region.
4290 if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
4291 return failure();
4292
4293 // If we find an 'else' keyword then parse the 'else' region.
4294 if (!parser.parseOptionalKeyword("else")) {
4295 if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
4296 return failure();
4297 }
4298
4299 // Parse the optional attribute list.
4300 if (parser.parseOptionalAttrDict(result.attributes))
4301 return failure();
4302 return success();
4303}
4304
4305void IfOp::print(OpAsmPrinter &p) {
4306 p << " " << getCondition();
4307
4308 printInitializationList(p, getThenGraph().front().getArguments(),
4309 getInputList(), " ");
4310 p << " : ";
4311 p << getCondition().getType();
4312
4313 if (!getInputList().empty()) {
4314 p << " (";
4315 llvm::interleaveComma(getInputList().getTypes(), p);
4316 p << ")";
4317 }
4318 p.printArrowTypeList(getResultTypes());
4319 p << " ";
4320
4321 p.printRegion(getThenGraph());
4322
4323 // Print the 'else' region if it exists and has a block.
4324 auto &elseRegion = getElseGraph();
4325 if (!elseRegion.empty()) {
4326 p << " else ";
4327 p.printRegion(elseRegion);
4328 }
4329
4330 p.printOptionalAttrDict((*this)->getAttrs());
4331}
4332
4333LogicalResult IfOp::verify() {
4334 if (errorIfTypeOrShapeMismatch(*this, getThenGraph().front().getArguments(),
4335 "'then_graph' arguments", getInputList(),
4336 "'input_list'")
4337 .failed())
4338 return failure();
4339
4340 if (errorIfTypeOrShapeMismatch(*this, getElseGraph().front().getArguments(),
4341 "'else_graph' arguments", getInputList(),
4342 "'input_list'")
4343 .failed())
4344 return failure();
4345
4346 // If the terminator is missing, MLIR's own region verification reports it.
4347 if (getThenGraph().front().mightHaveTerminator()) {
4348 auto thenYield =
4349 dyn_cast<tosa::YieldOp>(getThenGraph().front().getTerminator());
4350 if (thenYield && errorIfTypeOrShapeMismatch(
4351 *this, thenYield.getInputs(), "'then_graph' results",
4352 getOutputList(), "'output_list'")
4353 .failed())
4354 return failure();
4355 }
4356
4357 // If the terminator is missing, MLIR's own region verification reports it.
4358 if (getElseGraph().front().mightHaveTerminator()) {
4359 auto elseYield =
4360 dyn_cast<tosa::YieldOp>(getElseGraph().front().getTerminator());
4361 if (elseYield && errorIfTypeOrShapeMismatch(
4362 *this, elseYield.getInputs(), "'else_graph' results",
4363 getOutputList(), "'output_list'")
4364 .failed())
4365 return failure();
4366 }
4367
4368 auto condType = getCondition().getType();
4369 if (errorIfShapeNotSizeOne(*this, condType).failed())
4370 return emitOpError() << "'condition' must be a size 1 tensor, got "
4371 << condType;
4372
4373 return success();
4374}
4375
4376LogicalResult WhileOp::verify() {
4377 if (errorIfTypeOrShapeMismatch(*this, getInputList(), "'input_list'",
4378 getOutputList(), "'output_list'")
4379 .failed())
4380 return failure();
4381
4382 if (errorIfTypeOrShapeMismatch(*this, getCondGraph().front().getArguments(),
4383 "'cond_graph' arguments", getInputList(),
4384 "'input_list'")
4385 .failed())
4386 return failure();
4387
4388 if (errorIfTypeOrShapeMismatch(*this, getBodyGraph().front().getArguments(),
4389 "'body_graph' arguments", getInputList(),
4390 "'input_list'")
4391 .failed())
4392 return failure();
4393
4394 if (getBodyGraph().front().mightHaveTerminator()) {
4395 auto bodyYield =
4396 dyn_cast<tosa::YieldOp>(getBodyGraph().front().getTerminator());
4397 if (bodyYield && errorIfTypeOrShapeMismatch(*this, bodyYield.getInputs(),
4398 "'body_graph' results",
4399 getInputList(), "'input_list'")
4400 .failed())
4401 return failure();
4402 }
4403
4404 // Condition block output must be a single element tensor with a single bool
4405 // value.
4406 if (!getCondGraph().front().mightHaveTerminator())
4407 return success();
4408
4409 auto condYield =
4410 dyn_cast<tosa::YieldOp>(getCondGraph().front().getTerminator());
4411 if (!condYield)
4412 return success();
4413
4414 if (condYield.getInputs().size() != 1)
4415 return emitOpError() << "requires 'cond_graph' to have exactly one result";
4416
4417 auto condOutType = condYield.getInputs()[0].getType();
4418 if (errorIfShapeNotSizeOne(*this, condOutType).failed())
4419 return emitOpError() << "'cond_graph' result must be a size 1 tensor, got "
4420 << condOutType;
4421
4422 if (!getElementTypeOrSelf(condOutType).isInteger(1))
4423 return emitOpError() << "'cond_graph' result must be a boolean tensor, got "
4424 << condOutType;
4425
4426 return success();
4427}
4428
4429LogicalResult ReverseOp::verify() {
4430 if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
4431 /* outType = */ getOutput().getType())
4432 .failed())
4433 return failure();
4434 TensorType inputType = getInput1().getType();
4435 TensorType outputType = getOutput().getType();
4436 int32_t reverseAxis = getAxis();
4437
4438 if (reverseAxis < 0)
4439 return emitOpError("expected non-negative reverse axis");
4440 if (inputType.hasRank()) {
4441 int64_t inputRank = inputType.getRank();
4442 // We allow for a special case where the input/output shape has rank 0 and
4443 // axis is also 0.
4444 if (reverseAxis >= inputRank && (reverseAxis != 0 || inputRank != 0))
4445 return emitOpError("expect input tensor rank (")
4446 << inputRank << ") to be larger than reverse axis (" << reverseAxis
4447 << ")";
4448 }
4449 if (outputType.hasRank()) {
4450 int64_t outputRank = outputType.getRank();
4451 if (inputType.hasRank() && outputRank != inputType.getRank())
4452 return emitOpError(
4453 "expect output tensor rank to be equal to input tensor rank");
4454 if (reverseAxis >= outputRank && (reverseAxis != 0 || outputRank != 0))
4455 return emitOpError("expect output tensor rank (")
4456 << outputRank << ") to be larger than reverse axis ("
4457 << reverseAxis << ")";
4458 }
4459 return success();
4460}
4461
4462LogicalResult tosa::SelectOp::verify() {
4463 // verify input2 and input3 have same element type as output
4464 if (verifySameElementTypes(*this, /* inType = */ getOnTrue().getType(),
4465 /* outType = */ getOutput().getType())
4466 .failed() ||
4467 verifySameElementTypes(*this, /* inType = */ getOnFalse().getType(),
4468 /* outType = */ getOutput().getType())
4469 .failed()) {
4470 return failure();
4471 }
4472 // verify input1 has element type of bool
4473 auto predicateType = llvm::dyn_cast<ShapedType>(getPred().getType());
4474 if (!predicateType) {
4475 return emitOpError("expect shaped tensor for input1, got ")
4476 << getPred().getType();
4477 }
4478 auto predicateElementType = predicateType.getElementType();
4479 if (!predicateElementType.isInteger(1)) {
4480 return emitOpError("expect element type of bool for input1, got ")
4481 << predicateElementType;
4482 }
4483
4484 return success();
4485}
4486
4487LogicalResult tosa::VariableReadOp::verify() {
4488 if (verifyVariableOpErrorIf(*this, getOutput1().getType(), "'output1'")
4489 .failed())
4490 return failure();
4491
4492 return success();
4493}
4494
4495LogicalResult tosa::VariableWriteOp::verify() {
4496 if (verifyVariableOpErrorIf(*this, getInput1().getType(), "'input1'")
4497 .failed())
4498 return failure();
4499
4500 return success();
4501}
4502
4503 // The parse and print methods of WhileOp follow the implementation of the SCF dialect.
4504ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
4505 SmallVector<OpAsmParser::Argument, 4> regionArgs;
4506 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
4507 Region *cond = result.addRegion();
4508 Region *body = result.addRegion();
4509
4510 OptionalParseResult listResult =
4511 parser.parseOptionalAssignmentList(regionArgs, operands);
4512 if (listResult.has_value() && failed(listResult.value()))
4513 return failure();
4514
4515 FunctionType functionType;
4516 SMLoc typeLoc = parser.getCurrentLocation();
4517 if (failed(parser.parseColonType(functionType)))
4518 return failure();
4519
4520 result.addTypes(functionType.getResults());
4521
4522 if (functionType.getNumInputs() != operands.size()) {
4523 return parser.emitError(typeLoc)
4524 << "expected as many input types as operands "
4525 << "(expected " << operands.size() << " got "
4526 << functionType.getNumInputs() << ")";
4527 }
4528
4529 // Resolve input operands.
4530 if (failed(parser.resolveOperands(operands, functionType.getInputs(),
4531 parser.getCurrentLocation(),
4532 result.operands)))
4533 return failure();
4534
4535 // Propagate the types into the region arguments.
4536 for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
4537 regionArgs[i].type = functionType.getInput(i);
4538
4539 return failure(parser.parseRegion(*cond, regionArgs) ||
4540 parser.parseKeyword("do") || parser.parseRegion(*body) ||
4541 parser.parseOptionalAttrDictWithKeyword(result.attributes));
4542}
4543
4544void WhileOp::print(OpAsmPrinter &printer) {
4545  printInitializationList(printer, getCondGraph().front().getArguments(),
4546                          getInputList(), " ");
4547  printer << " : ";
4548  printer.printFunctionalType(getInputList().getTypes(),
4549                              getResults().getTypes());
4550  printer << ' ';
4551  printer.printRegion(getCondGraph(), /*printEntryBlockArgs=*/false);
4552  printer << " do ";
4553  printer.printRegion(getBodyGraph());
4554  printer.printOptionalAttrDictWithKeyword((*this)->getAttrs());
4555}
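// For reference, a hedged sketch (not quoted from a test) of the custom
// textual form that the parse/print methods above handle; the value names and
// tensor types are illustrative only:
//
//   %result = tosa.while_loop (%acc = %init) : (tensor<i32>) -> tensor<i32> {
//     // cond region: entry block arguments come from the assignment list;
//     // terminates with a tosa.yield of the i1 continue condition
//   } do {
//   ^bb0(%arg0: tensor<i32>):
//     // body region: terminates with a tosa.yield of the next iteration values
//   }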
4556
4557// Create a rank-1 const tensor holding the zero point of the source tensor.
4558std::optional<Value> mlir::tosa::createZeroPointTensor(OpBuilder &builder,
4559 Location loc,
4560 Type srcElemType,
4561 int64_t zp) {
4562 srcElemType = getStorageElementTypeOrSelf(srcElemType);
4563 auto zpType = mlir::RankedTensorType::get({1}, srcElemType);
4564 if (llvm::isa<FloatType>(srcElemType)) {
4565 auto zpAttr = DenseElementsAttr::get(
4566 zpType, builder.getFloatAttr(srcElemType, static_cast<double>(zp)));
4567 return tosa::ConstOp::create(builder, loc, zpType, zpAttr);
4568 }
4569 if (llvm::isa<IntegerType>(srcElemType)) {
4570 auto zpAttr =
4571 DenseElementsAttr::get(zpType, builder.getIntegerAttr(srcElemType, zp));
4572 return tosa::ConstOp::create(builder, loc, zpType, zpAttr);
4573 }
4574  llvm::errs() << "zero point is not supported for this element type\n";
4575 return std::nullopt;
4576}
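// A hedged usage sketch (illustrative, not from this file): for an i8 source
// element type and zp == -128, the helper above materializes roughly
//   %zp = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8>
// and it returns std::nullopt for element types that are neither float nor
// integer.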
4577
4578//===----------------------------------------------------------------------===//
4579// TOSA Shape and Shape Operators Helper functions.
4580//===----------------------------------------------------------------------===//
4581
4582bool mlir::tosa::isa_tosa_shape_type(mlir::Type t) {
4583 return mlir::isa<tosa::shapeType>(t);
4584}
4585
4586LogicalResult
4587mlir::tosa::shapeType::verify(function_ref<InFlightDiagnostic()> emitError,
4588 int rank) {
4589 if (rank < 0)
4590 return emitError() << "invalid rank (must be >= 0): " << rank;
4591 return success();
4592}
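// As an illustration of the check above (assumed syntax): !tosa.shape<4> and
// !tosa.shape<0> are valid shape types, while a negative rank such as
// !tosa.shape<-1> is rejected.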
4593
4594LogicalResult OpTrait::tosa::verifyTosaResolvableShapeOperands(Operation *op) {
4595 for (auto v : op->getOperands()) {
4596 if (mlir::isa<::mlir::tosa::shapeType>(v.getType())) {
4597 Operation *definingOp = v.getDefiningOp();
4598 if (!definingOp || !definingOp->hasTrait<TosaShapeOperator>()) {
4599 return op->emitOpError("shape operand is not compile time resolvable");
4600 }
4601 }
4602 }
4603 return success();
4604}
4605
4606LogicalResult OpTrait::tosa::verifyTosaShapeOperator(Operation *op) {
4607 for (auto type : op->getOperandTypes()) {
4608 if (!mlir::isa<mlir::tosa::shapeType>(type)) {
4609 return op->emitOpError("must have operands with tosa shape type");
4610 }
4611 }
4612 for (auto type : op->getResultTypes()) {
4613 if (!mlir::isa<mlir::tosa::shapeType>(type)) {
4614 return op->emitOpError("must have result with tosa shape type");
4615 }
4616 }
4617 return success();
4618}
4619
4620LogicalResult
4621OpTrait::tosa::verifyTosaShapeOperatorWithSameRanks(Operation *op) {
4622 if (failed(OpTrait::impl::verifyAtLeastNOperands(op, 1)) ||
4623 failed(verifyTosaShapeOperator(op)))
4624 return failure();
4625
4626  // Helper lambda that returns the rank of a tosa shape type.
4627 auto getRank = [](const Type type) {
4628 return mlir::cast<mlir::tosa::shapeType>(type).getRank();
4629 };
4630 auto operandTypes = op->getOperandTypes();
4631 auto resultTypes = op->getResultTypes();
4632
4633 auto rank = getRank(*op->getOperandTypes().begin());
4634 for (auto type : operandTypes) {
4635 if (getRank(type) != rank) {
4636 return op->emitOpError("operands don't have matching ranks");
4637 }
4638 }
4639 for (auto type : resultTypes) {
4640 if (getRank(type) != rank) {
4641 return op->emitOpError("result shape has different rank than operands");
4642 }
4643 }
4644 return success();
4645}
4646
4647//===----------------------------------------------------------------------===//
4648// TOSA Shape Operators verify functions.
4649//===----------------------------------------------------------------------===//
4650
4651LogicalResult tosa::ConstShapeOp::verify() {
4652  // Check that the values attribute has rank 1.
4653 auto valuesRank = getValues().getType().getRank();
4654 if (valuesRank != 1)
4655 return emitOpError("expect elements in attribute values with rank 1");
4656  // Check that the values element count equals the rank of the result shape.
4657 auto count = getValues().getNumElements();
4658 auto rank = (cast<tosa::shapeType>(getResult().getType())).getRank();
4659 if (count != rank && (count != 1 || rank != 0)) {
4660 return emitOpError("expect number of elements in attribute values (")
4661 << count << ") to be equal to the rank (" << rank
4662 << ") for the result shape type";
4663 }
4664 return success();
4665}
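// A hedged example of what this verifier accepts (syntax is a sketch): a
// rank-3 result whose 'values' attribute holds exactly three elements, e.g.
//   %shape = tosa.const_shape {values = dense<[1, 2, 3]> : tensor<3xindex>} : !tosa.shape<3>
// An element count that differs from the result rank is rejected above.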
4666
4667//===----------------------------------------------------------------------===//
4668// TOSA Attribute Definitions.
4669//===----------------------------------------------------------------------===//
4670
4671#define GET_ATTRDEF_CLASSES
4672#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
4673
4674//===----------------------------------------------------------------------===//
4675// TOSA Type Definitions.
4676//===----------------------------------------------------------------------===//
4677#define GET_TYPEDEF_CLASSES
4678#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
4679
4680//===----------------------------------------------------------------------===//
4681// TOSA Operator Definitions.
4682//===----------------------------------------------------------------------===//
4683
4684#define GET_OP_CLASSES
4685#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"