TosaOps.cpp
1//===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// \file
10// This file implements the TOSA Specification:
11// https://www.mlplatform.org/tosa/tosa_spec.html
12//
13//===----------------------------------------------------------------------===//
14
24#include "mlir/IR/Matchers.h"
28#include "llvm/ADT/APFloat.h"
29#include "llvm/ADT/TypeSwitch.h"
30
31#include <numeric>
32
33using namespace mlir;
34using namespace mlir::tosa;
35
36#include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"
38
39//===----------------------------------------------------------------------===//
40// Tosa dialect interface includes.
41//===----------------------------------------------------------------------===//
42
43#include "mlir/Dialect/Tosa/IR/TosaAvailability.cpp.inc"
44#include "mlir/Dialect/Tosa/IR/TosaEnums.cpp.inc"
45#include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
46#include "mlir/Dialect/Tosa/IR/TosaOpAvailabilityImpl.inc"
47
48namespace {
49#include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"
50
51//===----------------------------------------------------------------------===//
52// Dialect Function Inliner Interface.
53//===----------------------------------------------------------------------===//
54struct TosaInlinerInterface : public DialectInlinerInterface {
55 using DialectInlinerInterface::DialectInlinerInterface;
56
57 //===--------------------------------------------------------------------===//
58 // Analysis Hooks.
59 //===--------------------------------------------------------------------===//
60
61 /// All operations can be inlined by default.
62 bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
63 IRMapping &map) const final {
64 return true;
65 }
66
67 /// All regions with If and While parent operators can be inlined.
68 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
69 IRMapping &map) const final {
70 return (isa<tosa::IfOp>(dest->getParentOp()) ||
71 isa<tosa::WhileOp>(dest->getParentOp()));
72 }
73};
74
75/// This class implements the bytecode interface for the Tosa dialect.
76struct TosaDialectBytecodeInterface : public BytecodeDialectInterface {
77 TosaDialectBytecodeInterface(Dialect *dialect)
78 : BytecodeDialectInterface(dialect) {}
79
80 //===--------------------------------------------------------------------===//
81 // Attributes
82
83 Attribute readAttribute(DialectBytecodeReader &reader) const override {
84 return ::readAttribute(getContext(), reader);
85 }
86
87 LogicalResult writeAttribute(Attribute attr,
88 DialectBytecodeWriter &writer) const override {
89 return ::writeAttribute(attr, writer);
90 }
91
92 //===--------------------------------------------------------------------===//
93 // Types
94
95 Type readType(DialectBytecodeReader &reader) const override {
96 return ::readType(getContext(), reader);
97 }
98
99 LogicalResult writeType(Type type,
100 DialectBytecodeWriter &writer) const override {
101 return ::writeType(type, writer);
102 }
103
104 void writeVersion(DialectBytecodeWriter &writer) const final {
105 // TODO: Populate.
106 }
107
108 std::unique_ptr<DialectVersion>
109 readVersion(DialectBytecodeReader &reader) const final {
110 // TODO: Populate
111 reader.emitError("Dialect does not support versioning");
112 return nullptr;
113 }
114
115 LogicalResult upgradeFromVersion(Operation *topLevelOp,
116 const DialectVersion &version) const final {
117 return success();
118 }
119};
120
121} // namespace
122
123//===----------------------------------------------------------------------===//
124// TOSA control flow support.
125//===----------------------------------------------------------------------===//
126
127/// Returns the while loop body.
128SmallVector<Region *> tosa::WhileOp::getLoopRegions() {
129 return {&getBodyGraph()};
130}
131
132//===----------------------------------------------------------------------===//
133// TOSA variable operator support.
134//===----------------------------------------------------------------------===//
135
136SmallVector<int64_t> mlir::tosa::convertToMlirShape(ArrayRef<int64_t> shape) {
137 return to_vector(llvm::map_range(shape, [](int64_t dim) {
138 return dim == -1 ? ShapedType::kDynamic : dim;
139 }));
140}
141
142// returns type of variable op
143RankedTensorType mlir::tosa::getVariableType(tosa::VariableOp variableOp) {
144 Type elementType = variableOp.getType();
145 DenseIntElementsAttr varShapeAttr = variableOp.getVarShape();
146 auto shape = convertToMlirShape(to_vector(varShapeAttr.getValues<int64_t>()));
147 return RankedTensorType::get(shape, elementType);
148}
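// Worked example (values illustrative): a tosa.variable declared with
// var_shape = [2, -1] and an f32 element type is reported by getVariableType()
// as tensor<2x?xf32>, since convertToMlirShape maps each -1 entry to
// ShapedType::kDynamic.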
149
150//===----------------------------------------------------------------------===//
151// Tosa dialect initialization.
152//===----------------------------------------------------------------------===//
153
154void TosaDialect::initialize() {
155 addTypes<
156#define GET_TYPEDEF_LIST
157#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
158 >();
159 addOperations<
160#define GET_OP_LIST
161#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
162 >();
163 addAttributes<
164#define GET_ATTRDEF_LIST
165#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
166 >();
167 addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
168 declarePromisedInterfaces<
169 shard::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
170 ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp,
171 LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,
172 LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,
173 BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp,
174 NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp,
175 GreaterEqualOp, MatMulOp>();
176}
177
178Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
179 Type type, Location loc) {
180 // Tosa dialect constants only support ElementsAttr unlike standard dialect
181 // constant which supports all attributes.
182 if (llvm::isa<shapeType>(type) && llvm::isa<DenseIntElementsAttr>(value)) {
183 return tosa::ConstShapeOp::create(builder, loc, type,
184 llvm::cast<DenseIntElementsAttr>(value));
185 }
186 if (llvm::isa<ElementsAttr>(value))
187 return tosa::ConstOp::create(builder, loc, type,
188 llvm::cast<ElementsAttr>(value));
189 return nullptr;
190}
191
192//===----------------------------------------------------------------------===//
193// Parsers and printers
194//===----------------------------------------------------------------------===//
195
196namespace {
197
198ParseResult getShapeAndElementType(OpAsmParser &parser, Type parsedType,
199 DenseElementsAttr &varShapeAttr,
200 TypeAttr &typeAttr) {
201 if (auto shapedType = dyn_cast<ShapedType>(parsedType)) {
202 if (!shapedType.hasRank())
203 return parser.emitError(parser.getCurrentLocation())
204 << "expected ranked type";
205
206 auto elementType = shapedType.getElementType();
207 typeAttr = TypeAttr::get(elementType);
208 ArrayRef<int64_t> shape = shapedType.getShape();
209 Builder builder(parser.getContext());
210 varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
211 return success();
212 }
213 return parser.emitError(parser.getCurrentLocation())
214 << "expected shaped type";
215}
216
217} // namespace
218
219// parses the optional initial value or type for a tosa variable
220// with initial value:
221// tosa.variable @name = dense<0.0> : tensor<1x8xf32>
222//
223// without initial value:
224// tosa.variable @name : tensor<1x8xf32>
225ParseResult mlir::tosa::parseVariableOpTypeOrInitialValue(
226 OpAsmParser &parser, DenseElementsAttr &varShapeAttr, TypeAttr &typeAttr,
227 Attribute &initialValueAttr) {
228 if (succeeded(parser.parseOptionalEqual())) {
229 if (failed(parser.parseAttribute(initialValueAttr))) {
230 return parser.emitError(parser.getCurrentLocation())
231 << "expected attribute";
232 }
233 if (auto typedAttr = dyn_cast<TypedAttr>(initialValueAttr)) {
234 return getShapeAndElementType(parser, typedAttr.getType(), varShapeAttr,
235 typeAttr);
236 }
237 return parser.emitError(parser.getCurrentLocation())
238 << "expected Typed attr";
239 }
240
241 initialValueAttr = nullptr;
242 Type parsedType;
243 if (failed(parser.parseColonType(parsedType))) {
244 return parser.emitError(parser.getCurrentLocation())
245 << "expected type after colon";
246 }
247 return getShapeAndElementType(parser, parsedType, varShapeAttr, typeAttr);
248}
249
250void mlir::tosa::printVariableOpTypeOrInitialValue(
251 OpAsmPrinter &p, Operation *op, DenseElementsAttr varShapeAttr,
252 TypeAttr typeAttr, Attribute initialValueAttr) {
253 bool needsSpace = false;
254 if (!dyn_cast_or_null<TypedAttr>(initialValueAttr)) {
255 auto shape =
256 convertToMlirShape(to_vector(varShapeAttr.getValues<int64_t>()));
257 Type elementType = typeAttr.getValue();
258 RankedTensorType tensorType =
259 RankedTensorType::get(ArrayRef<int64_t>(shape), elementType);
260 auto tensorTypeAttr = TypeAttr::get(tensorType);
261 p << ": ";
262 p.printAttribute(tensorTypeAttr);
263 needsSpace = true; // subsequent attr value needs a space separator
264 }
265 if (initialValueAttr) {
266 if (needsSpace)
267 p << ' ';
268 p << "= ";
269 p.printAttribute(initialValueAttr);
270 }
271}
272
273namespace {
274
275// parse attributes with special handling for tosa enum attributes
276template <typename EnumType>
277ParseResult parseAttrEntryWithEnumHandling(OpAsmParser &parser,
278 NamedAttrList &outAttrs) {
279 llvm::StringRef name;
280 if (parser.parseOptionalKeyword(&name) || parser.parseEqual())
281 return failure();
282
283 // special handling: rounding_mode accepts a *bare* RoundingMode enum
284 // keyword.
285 llvm::StringRef kw;
286 if constexpr (std::is_same_v<EnumType, tosa::RoundingMode>) {
287 if (name == "rounding_mode" &&
288 succeeded(parser.parseOptionalKeyword(&kw))) {
289 auto sym = symbolizeRoundingMode(kw);
290 if (!sym)
291 return parser.emitError(parser.getCurrentLocation())
292 << "invalid rounding_mode value: " << kw;
293 auto attr = RoundingModeAttr::get(parser.getContext(), sym.value());
294 outAttrs.push_back(NamedAttribute(name, attr));
295 return success();
296 }
297 }
298 // special handling: mode accepts a *bare* ResizeMode enum keyword.
299 if constexpr (std::is_same_v<EnumType, tosa::ResizeMode>) {
300 if (name == "mode" && succeeded(parser.parseOptionalKeyword(&kw))) {
301 auto sym = symbolizeResizeMode(kw);
302 if (!sym)
303 return parser.emitError(parser.getCurrentLocation())
304 << "invalid resize mode value: " << kw;
305 auto attr = ResizeModeAttr::get(parser.getContext(), sym.value());
306 outAttrs.push_back(NamedAttribute(name, attr));
307 return success();
308 }
309 }
310 // special handling: nan_mode accepts a *bare* NanPropagationMode enum
311 // keyword.
312 if constexpr (std::is_same_v<EnumType, tosa::NanPropagationMode>) {
313 if (name == "nan_mode" && succeeded(parser.parseOptionalKeyword(&kw))) {
314 auto sym = symbolizeNanPropagationMode(kw);
315 if (!sym)
316 return parser.emitError(parser.getCurrentLocation())
317 << "invalid nan_mode value: " << kw;
318 auto attr = NanPropagationModeAttr::get(parser.getContext(), sym.value());
319 outAttrs.push_back(NamedAttribute(name, attr));
320 return success();
321 }
322 }
323
324 // special handling: block_size accepts a *bare* BlockSize enum keyword.
325 if constexpr (std::is_same_v<EnumType, tosa::BlockSize>) {
326 if (name == "block_size" && succeeded(parser.parseOptionalKeyword(&kw))) {
327 auto sym = symbolizeBlockSize(kw);
328 if (!sym)
329 return parser.emitError(parser.getCurrentLocation())
330 << "invalid block_size value: " << kw;
331 auto attr = BlockSizeAttr::get(parser.getContext(), sym.value());
332 outAttrs.push_back(NamedAttribute(name, attr));
333 return success();
334 }
335 }
336
337 // Default path: parse any normal attribute literal, including fully qualified
338 // enum keyword
339 Attribute attr;
340 return parser.parseAttribute(attr, name, outAttrs);
341}
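// Sketch of the two spellings the entry parser accepts (attribute values are
// illustrative): the bare enum keyword form
//   { rounding_mode = SINGLE_ROUND, scale32 = true }
// is handled by the special cases above, while the default path still parses
// the fully qualified attribute literal for the same enums as well as any
// ordinary attribute.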
342
343template <typename EnumType>
344ParseResult parseWithEnumHandling(OpAsmParser &parser, OperationState &result) {
345 // parse operands
346 SmallVector<OpAsmParser::UnresolvedOperand> operands;
347 if (parser.parseCommaSeparatedList(
348 [&]() { return parser.parseOperand(operands.emplace_back()); }))
349 return failure();
350
351 // Parse { attr-dict } with special handling for enum bare token
352 NamedAttrList attrs;
353 if (succeeded(parser.parseOptionalLBrace()) &&
354 failed(parser.parseOptionalRBrace())) {
355 do {
356 if (parseAttrEntryWithEnumHandling<EnumType>(parser, attrs))
357 return failure();
358 } while (succeeded(parser.parseOptionalComma()));
359 if (parser.parseRBrace())
360 return failure();
361 }
362
363 FunctionType fnTy;
364 if (parser.parseColonType(fnTy))
365 return failure();
366
367 // Resolve operands and types
368 if (failed(parser.resolveOperands(operands, fnTy.getInputs(),
369 parser.getCurrentLocation(),
370 result.operands)))
371 return failure();
372
373 result.addTypes(fnTy.getResults());
374 result.addAttributes(attrs);
375
376 return success();
377}
378
379void printNamedAttr(OpAsmPrinter &parser, const NamedAttribute namedAttr) {
380 parser << namedAttr.getName().strref() << " = ";
381 auto attr = namedAttr.getValue();
382 if (auto roundingModeAttr = dyn_cast<tosa::RoundingModeAttr>(attr)) {
383 parser << roundingModeAttr.getValue();
384 } else if (auto resizeModeAttr = dyn_cast<tosa::ResizeModeAttr>(attr)) {
385 parser << resizeModeAttr.getValue();
386 } else if (auto nanPropagationModeAttr =
387 dyn_cast<tosa::NanPropagationModeAttr>(attr)) {
388 parser << nanPropagationModeAttr.getValue();
389 } else if (auto blockSizeAttr = dyn_cast<tosa::BlockSizeAttr>(attr)) {
390 parser << blockSizeAttr.getValue();
391 } else {
392 parser.printAttribute(attr);
393 }
394}
395
396// print with special handling for default valued NanPropagationMode attribute
397void printWithNanPropagationHandling(OpAsmPrinter &parser, Operation *op) {
398 parser << " ";
399 parser.printOperands(op->getOperands());
400
401 NamedAttrList toPrint(op->getAttrs());
402 // remove default NanPropagate attribute
403 const auto kDefaultNanValue = NanPropagationMode::PROPAGATE;
404 for (auto attr : op->getAttrs()) {
405 if (auto nanAttr = dyn_cast<NanPropagationModeAttr>(attr.getValue())) {
406 if (nanAttr.getValue() == kDefaultNanValue) {
407 // elide from toPrint
408 toPrint.erase(attr.getName());
409 break;
410 }
411 }
412 }
413
414 if (!toPrint.empty()) {
415 parser << " {";
416 llvm::interleaveComma(toPrint, parser, [&](const NamedAttribute namedAttr) {
417 printNamedAttr(parser, namedAttr);
418 });
419 parser << "}";
420 }
421
422 parser << " : ";
423 parser.printFunctionalType(op);
424}
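// Round-trip sketch (operands and types illustrative): since PROPAGATE is the
// default, an op written as
//   %0 = tosa.maximum %a, %b {nan_mode = PROPAGATE} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// prints back without the attribute, whereas a non-default {nan_mode = IGNORE}
// is preserved.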
425
426// print with special handling for enums: RoundingMode, ResizeMode
427void printWithEnumHandling(OpAsmPrinter &parser, Operation *op) {
428 parser << " ";
429 parser.printOperands(op->getOperands());
430
431 if (!op->getAttrs().empty()) {
432 parser << " {";
433 llvm::interleaveComma(op->getAttrs(), parser,
434 [&](const NamedAttribute namedAttr) {
435 printNamedAttr(parser, namedAttr);
436 });
437 parser << "}";
438 }
439
440 parser << " : ";
441 parser.printFunctionalType(op);
442}
443
444} // namespace
445
446ParseResult RescaleOp::parse(OpAsmParser &parser, OperationState &result) {
447 return parseWithEnumHandling<tosa::RoundingMode>(parser, result);
448}
449
450void RescaleOp::print(OpAsmPrinter &parser) {
451 printWithEnumHandling(parser, *this);
452}
453
454ParseResult ApplyScaleOp::parse(OpAsmParser &parser, OperationState &result) {
455 return parseWithEnumHandling<tosa::RoundingMode>(parser, result);
456}
457
458void ApplyScaleOp::print(OpAsmPrinter &parser) {
459 printWithEnumHandling(parser, *this);
460}
461
462ParseResult ResizeOp::parse(OpAsmParser &parser, OperationState &result) {
463 return parseWithEnumHandling<tosa::ResizeMode>(parser, result);
464}
465
466void ResizeOp::print(OpAsmPrinter &parser) {
467 printWithEnumHandling(parser, *this);
468}
469
470ParseResult ArgMaxOp::parse(OpAsmParser &parser, OperationState &result) {
471 return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
472}
473
474void ArgMaxOp::print(OpAsmPrinter &parser) {
475 printWithNanPropagationHandling(parser, *this);
476}
477
478ParseResult MaxPool2dOp::parse(OpAsmParser &parser, OperationState &result) {
479 return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
480}
481
482void MaxPool2dOp::print(OpAsmPrinter &parser) {
483 printWithNanPropagationHandling(parser, *this);
484}
485
486ParseResult ClampOp::parse(OpAsmParser &parser, OperationState &result) {
487 return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
488}
489
490void ClampOp::print(OpAsmPrinter &parser) {
491 printWithNanPropagationHandling(parser, *this);
492}
493
494ParseResult MaximumOp::parse(OpAsmParser &parser, OperationState &result) {
495 return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
496}
497
498void MaximumOp::print(OpAsmPrinter &parser) {
499 printWithNanPropagationHandling(parser, *this);
500}
501
502ParseResult MinimumOp::parse(OpAsmParser &parser, OperationState &result) {
503 return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
504}
505
506void MinimumOp::print(OpAsmPrinter &parser) {
507 printWithNanPropagationHandling(parser, *this);
508}
509
510ParseResult ReduceMaxOp::parse(OpAsmParser &parser, OperationState &result) {
511 return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
512}
513
514void ReduceMaxOp::print(OpAsmPrinter &parser) {
515 printWithNanPropagationHandling(parser, *this);
516}
517
518ParseResult ReduceMinOp::parse(OpAsmParser &parser, OperationState &result) {
519 return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
520}
521
522void ReduceMinOp::print(OpAsmPrinter &parser) {
523 printWithNanPropagationHandling(parser, *this);
524}
525
526ParseResult MatmulTBlockScaledOp::parse(OpAsmParser &parser,
527 OperationState &result) {
528 return parseWithEnumHandling<tosa::BlockSize>(parser, result);
529}
530
531void MatmulTBlockScaledOp::print(OpAsmPrinter &parser) {
532 printWithEnumHandling(parser, *this);
533}
534
535ParseResult CastFromBlockScaledOp::parse(OpAsmParser &parser,
536 OperationState &result) {
537 return parseWithEnumHandling<tosa::BlockSize>(parser, result);
538}
539
540void CastFromBlockScaledOp::print(OpAsmPrinter &parser) {
541 printWithEnumHandling(parser, *this);
542}
543
544ParseResult CastToBlockScaledOp::parse(OpAsmParser &parser,
545 OperationState &result) {
546 return parseWithEnumHandling<tosa::BlockSize>(parser, result);
547}
548
549void CastToBlockScaledOp::print(OpAsmPrinter &parser) {
550 printWithEnumHandling(parser, *this);
551}
552
553//===----------------------------------------------------------------------===//
554// Tosa utilities.
555//===----------------------------------------------------------------------===//
556
557static std::optional<int64_t> idivCheck(const int64_t lhs, const int64_t rhs) {
558 if (lhs % rhs != 0)
559 return std::nullopt;
560 return lhs / rhs;
561}
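// For example, idivCheck(12, 4) yields 3, while idivCheck(10, 4) yields
// std::nullopt because 10 is not evenly divisible by 4; callers turn the
// nullopt case into ERROR_IF diagnostics for non-divisible shape arithmetic.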
562
563Type mlir::tosa::getStorageElementTypeOrSelf(Type type) {
564 auto srcType = getElementTypeOrSelf(type);
565 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcType))
566 srcType = quantType.getStorageType();
567 return srcType;
568}
569
570Type mlir::tosa::getStorageElementTypeOrSelf(Value value) {
571 return getStorageElementTypeOrSelf(value.getType());
572}
573
574static LogicalResult verifyRescaleValueAndZpTypes(Operation *op, Value val,
575 Value valZp, StringRef name) {
576 Type eType = getStorageElementTypeOrSelf(val.getType());
577 Type eZpType = getStorageElementTypeOrSelf(valZp.getType());
578
579 bool bothInts =
580 mlir::isa<IntegerType>(eType) && mlir::isa<IntegerType>(eZpType);
581 bool sameBitWidth =
582 (eType.getIntOrFloatBitWidth() == eZpType.getIntOrFloatBitWidth());
583
584 if (!bothInts || !sameBitWidth) {
585 return op->emitOpError()
586 << "expected " << name << " and " << name
587 << "_zp to both be integer of the same bitwidth, but got " << eType
588 << " vs. " << eZpType;
589 }
590 return success();
591}
592
593// Create a pad_const constant tensor with value `val` of the required data type
594Value mlir::tosa::createPadConstTensor(OpBuilder &builder, Location loc,
595 Value src, int32_t val) {
596 const auto srcType = getElementTypeOrSelf(src);
597 const auto srcElemType = getStorageElementTypeOrSelf(src);
598 const auto padConstType = mlir::RankedTensorType::get({1}, srcType);
599 const auto padConstEType = mlir::RankedTensorType::get({1}, srcElemType);
600 const auto padConstAttr{
601 llvm::isa<FloatType>(srcElemType)
602 ? DenseElementsAttr::get(padConstEType,
603 builder.getFloatAttr(srcElemType, val))
604 : DenseElementsAttr::get(padConstEType,
605 builder.getIntegerAttr(srcElemType, val))};
606 return tosa::ConstOp::create(builder, loc, padConstType, padConstAttr);
607}
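// Sketch of the constant this produces (assuming an i8 source and val = 0):
//   %pad_const = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
// For a quantized source the attribute uses the storage element type while the
// result tensor keeps the source's (possibly quantized) element type.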
608
610 if (dyn_cast<tosa::mxint8Type>(type))
611 return 8;
612 return type.getIntOrFloatBitWidth();
613}
614
615//===----------------------------------------------------------------------===//
616// TOSA Operator Verifiers.
617//===----------------------------------------------------------------------===//
618
619template <typename T>
620static LogicalResult verifyConvOp(T op) {
621 const auto inputType = llvm::dyn_cast<TensorType>(op.getInput().getType());
622 const auto weightType = llvm::dyn_cast<TensorType>(op.getWeight().getType());
623
624 auto inputEType = inputType.getElementType();
625 auto weightEType = weightType.getElementType();
626 auto biasEType =
627 llvm::cast<ShapedType>(op.getBias().getType()).getElementType();
628 auto resultEType =
629 llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
630 bool biasIsFloat = llvm::isa<FloatType>(biasEType);
631 bool resultIsFloat = llvm::isa<FloatType>(resultEType);
632
633 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
634 inputEType = quantType.getStorageType();
635
636 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType))
637 weightEType = quantType.getStorageType();
638
639 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
640 biasEType = quantType.getStorageType();
641
642 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
643 resultEType = quantType.getStorageType();
644
645 if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
646 // for now, only enforce bias element type == result element type for
647 // float types.
648 op.emitOpError(
649 "expect both bias and result to have same element type, got ")
650 << biasEType << " and " << resultEType;
651 return failure();
652 }
653
654 if (isa<Float8E5M2Type>(inputEType) || isa<Float8E4M3FNType>(inputEType) ||
655 isa<Float8E5M2Type>(weightEType) || isa<Float8E4M3FNType>(weightEType)) {
656 if (inputEType != weightEType) {
657 op.emitOpError(
658 "expect both input and weight to have same element type, got ")
659 << inputEType << " and " << weightEType;
660 return failure();
661 }
662 }
663
664 bool inputIsFloat = llvm::isa<FloatType>(inputEType);
665 bool weightIsFloat = llvm::isa<FloatType>(weightEType);
666
667 // Either both must be float or both non-float.
668 if (inputIsFloat != weightIsFloat) {
669 op.emitOpError(
670 "expect both input and weight to be float or not together, got ")
671 << inputEType << " and " << weightEType;
672 return failure();
673 }
674
675 auto inputZpEType = getStorageElementTypeOrSelf(op.getInputZp().getType());
676 if (inputEType != inputZpEType) {
677 return op.emitOpError("expect both input and its zero point are the same "
678 "element type, got ")
679 << inputEType << " and " << inputZpEType;
680 }
681
682 auto weightZpEType = getStorageElementTypeOrSelf(op.getWeightZp().getType());
683 if (weightEType != weightZpEType) {
684 return op.emitOpError("expect both weight and its zero point are the same "
685 "element type, got ")
686 << weightEType << " and " << weightZpEType;
687 }
688
689 FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
690 if (succeeded(maybeIZp) && op.verifyInputZeroPoint(*maybeIZp).failed())
691 return failure();
692
693 FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
694 if (succeeded(maybeWZp) && op.verifyWeightZeroPoint(*maybeWZp).failed())
695 return failure();
696
697 return success();
698}
699
700LogicalResult tosa::ConstOp::verify() {
701
702 auto attrType = llvm::dyn_cast<TensorType>(getValuesAttr().getType());
703 auto outputType = llvm::dyn_cast<TensorType>(getOutput().getType());
704
705 if (!attrType || !outputType) {
706 emitOpError("expected tensors for attr/result type");
707 return failure();
708 }
709
710 if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
711 outputType.getElementType())) {
712 if (result.getStorageType() == attrType.getElementType())
713 return success();
714 }
715
716 if (attrType.getElementType() != outputType.getElementType()) {
717 emitOpError("expected same attr/result element types");
718 return failure();
719 }
720
721 return success();
722}
723
724template <typename T>
725static LogicalResult verifyConvOpModes(T op) {
726 auto inputEType =
727 llvm::cast<ShapedType>(op.getInput().getType()).getElementType();
728
729 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
730 inputEType = quantType.getStorageType();
731
732 auto accType = op.getAccType();
733 if (inputEType.isInteger(8) && !accType.isInteger(32))
734 return op.emitOpError("accumulator type for i8 tensor is not i32");
735
736 if (inputEType.isInteger(16) && !accType.isInteger(48))
737 return op.emitOpError("accumulator type for i16 tensor is not i48");
738
739 if (isa<Float8E5M2Type, Float8E4M3Type>(inputEType) && !accType.isF16())
740 return op.emitOpError("accumulator type for f8 tensor is not f16");
741
742 if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
743 return op.emitOpError("accumulator type for f16 tensor is not f16/f32");
744
745 if (inputEType.isBF16() && !accType.isF32())
746 return op.emitOpError("accumulator type for bf16 tensor is not f32");
747
748 if (inputEType.isF32() && !accType.isF32())
749 return op.emitOpError("accumulator type for f32 tensor is not f32");
750
751 auto resultEType =
752 llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
753
754 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
755 resultEType = quantType.getStorageType();
756
757 return success();
758}
759
760//===----------------------------------------------------------------------===//
761// ERROR_IF functions.
762// ERROR_IF is a predicate that must set an error if the condition holds.
763//===----------------------------------------------------------------------===//
764
765template <typename T>
766static LogicalResult verifyConvOpErrorIf(T op) {
767 llvm::ArrayRef<int64_t> padding = op.getPad();
768 if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
769 return op.emitOpError("expect all padding values to be >= 0, got ")
770 << padding;
771
772 llvm::ArrayRef<int64_t> strides = op.getStride();
773 if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
774 return op.emitOpError("expect all stride values to be >= 1, got ")
775 << strides;
776
777 llvm::ArrayRef<int64_t> dilations = op.getDilation();
778 if (llvm::any_of(dilations, [](int64_t d) { return d < 1; }))
779 return op.emitOpError("expect all dilation values to be >= 1, got ")
780 << dilations;
781
782 const RankedTensorType outputType =
783 llvm::dyn_cast<RankedTensorType>(op.getOutput().getType());
784 if (!outputType)
785 // Skip following checks if output is not ranked
786 return success();
787
788 const RankedTensorType inputType =
789 llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
790 const RankedTensorType weightType =
791 llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
792
793 if (inputType && weightType) {
794 const auto verifyOutputSize =
795 [&op](const int64_t inputSize, const int64_t kernelSize,
796 const int64_t outputSize, const int64_t padBefore,
797 const int64_t padAfter, const int64_t stride,
798 const int64_t dilation, const llvm::StringRef dimName,
799 const llvm::StringRef dimAxis,
800 const llvm::StringRef padBeforeName,
801 const llvm::StringRef padAfterName) -> LogicalResult {
802 if (inputSize == ShapedType::kDynamic ||
803 kernelSize == ShapedType::kDynamic)
804 return success();
805
806 // ERROR_IF: O != idiv_check(I - 1 + pa + pb - (K - 1) * d, s) + 1
807
808 const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
809 inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
810 stride);
811 if (!calculatedOutSizeMinusOne.has_value())
812 return op.emitOpError("expected input_")
813 << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
814 << padAfterName << " - (kernel_" << dimName
815 << " - 1) * dilation_" << dimAxis
816 << " to be wholly divisible by stride_" << dimAxis << ", got ("
817 << inputSize << " - 1 + " << padBefore << " + " << padAfter
818 << " - (" << kernelSize << " - 1) * " << dilation << ") / "
819 << stride;
820
821 const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
822 if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
823 return op.emitOpError("calculated output ")
824 << dimName << " did not match expected: "
825 << "calculated=" << calculatedOutSize
826 << ", expected=" << outputSize;
827
828 return success();
829 };
830
831 // input = [_,IH,IW,_], weight = [_,KH,KW,_], output = [_,OH,OW,_]
832 if constexpr (std::is_same<T, tosa::Conv2DOp>::value) {
833 if (failed(verifyOutputSize(
834 inputType.getDimSize(1), weightType.getDimSize(1),
835 outputType.getDimSize(1), padding[0], padding[1], strides[0],
836 dilations[0], "height", "y", "top", "bottom")))
837 return failure();
838
839 if (failed(verifyOutputSize(
840 inputType.getDimSize(2), weightType.getDimSize(2),
841 outputType.getDimSize(2), padding[2], padding[3], strides[1],
842 dilations[1], "width", "x", "left", "right")))
843 return failure();
844 }
845
846 // input = [_,IH,IW,_], weight = [KH,KW,_,_], output = [_,OH,OW,_]
847 if constexpr (std::is_same<T, tosa::DepthwiseConv2DOp>::value) {
848 if (failed(verifyOutputSize(
849 inputType.getDimSize(1), weightType.getDimSize(0),
850 outputType.getDimSize(1), padding[0], padding[1], strides[0],
851 dilations[0], "height", "y", "top", "bottom")))
852 return failure();
853
854 if (failed(verifyOutputSize(
855 inputType.getDimSize(2), weightType.getDimSize(1),
856 outputType.getDimSize(2), padding[2], padding[3], strides[1],
857 dilations[1], "width", "x", "left", "right")))
858 return failure();
859 }
860
861 // input = [_,ID,IH,IW,_], weight = [_,KD,KH,KW,_], output = [_,OD,OH,OW,_]
862 if constexpr (std::is_same<T, tosa::Conv3DOp>::value) {
863 if (failed(verifyOutputSize(
864 inputType.getDimSize(1), weightType.getDimSize(1),
865 outputType.getDimSize(1), padding[0], padding[1], strides[0],
866 dilations[0], "depth", "d", "front", "back")))
867 return failure();
868
869 if (failed(verifyOutputSize(
870 inputType.getDimSize(2), weightType.getDimSize(2),
871 outputType.getDimSize(2), padding[2], padding[3], strides[1],
872 dilations[1], "height", "y", "top", "bottom")))
873 return failure();
874
875 if (failed(verifyOutputSize(
876 inputType.getDimSize(3), weightType.getDimSize(3),
877 outputType.getDimSize(3), padding[4], padding[5], strides[2],
878 dilations[2], "width", "x", "left", "right")))
879 return failure();
880 }
881 }
882
883 const RankedTensorType biasType =
884 llvm::dyn_cast<RankedTensorType>(op.getBias().getType());
885 if (!biasType)
886 // Skip following checks if bias is not ranked
887 return success();
888
889 const int64_t biasChannels = biasType.getDimSize(0);
890 const int64_t outputChannels =
891 outputType.getDimSize(outputType.getRank() - 1);
892 if (biasChannels == ShapedType::kDynamic ||
893 outputChannels == ShapedType::kDynamic)
894 // Skip following checks if biasChannels or outputChannels is dynamic dim
895 return success();
896
897 if (biasChannels != outputChannels && biasChannels != 1)
898 return op.emitOpError(
899 "bias channels expected to be equal to output channels (")
900 << outputChannels << ") or 1, got " << biasChannels;
901
902 return success();
903}
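// Worked example of the output-size rule above (numbers illustrative): with
// input_height = 15, kernel_height = 3, pad_top = pad_bottom = 1, stride_y = 2
// and dilation_y = 1,
//   OH = idiv_check(15 - 1 + 1 + 1 - (3 - 1) * 1, 2) + 1 = idiv_check(14, 2) + 1 = 8,
// so a static output height other than 8, or a numerator that is not divisible
// by the stride, is rejected.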
904
905// Verify that the two given types have the same element type and compatible shapes.
906static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, Type type1,
907 StringRef name1, Type type2,
908 StringRef name2) {
909 auto shapeType1 = dyn_cast<ShapedType>(type1);
910 auto shapeType2 = dyn_cast<ShapedType>(type2);
911 if (!shapeType1 || !shapeType2)
912 return failure();
913
914 auto elemType1 = shapeType1.getElementType();
915 auto elemType2 = shapeType2.getElementType();
916 if (elemType1 != elemType2)
917 return op->emitOpError()
918 << "require same element type for " << name1 << " (" << elemType1
919 << ") and " << name2 << " (" << elemType2 << ")";
920
921 if (failed(verifyCompatibleShape(type1, type2)))
922 return op->emitOpError()
923 << "require same shapes for " << name1 << " (" << type1 << ") and "
924 << name2 << " (" << type2 << ")";
925
926 return success();
927}
928
929// Verify that the two given tensor lists have the same length and matching element types and shapes.
930static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, ValueRange list1,
931 StringRef name1,
932 ValueRange list2,
933 StringRef name2) {
934 if (list1.size() != list2.size())
935 return op->emitOpError()
936 << "require same number of values in " << name1 << " ("
937 << list1.size() << ") and " << name2 << " (" << list2.size() << ")";
938
939 for (auto [type1, type2] :
940 llvm::zip_equal(list1.getTypes(), list2.getTypes())) {
941 if (errorIfTypeOrShapeMismatch(op, type1, name1, type2, name2).failed())
942 return failure();
943 }
944
945 return success();
946}
947
948static inline LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type) {
949 ShapeAdaptor shapeAdaptor(type);
950 if (!shapeAdaptor.hasRank() || !shapeAdaptor.hasStaticShape())
951 return success();
952
953 return shapeAdaptor.getNumElements() == 1 ? success() : failure();
954}
955
956template <typename T>
957static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) {
958 Operation *symTableOp =
959 op->template getParentWithTrait<OpTrait::SymbolTable>();
960 if (!symTableOp)
961 // If the operation is not enclosed within a symbol table scope, we cannot
962 // verify it against its declaration.
963 return success();
964
965 SymbolTable symTable(symTableOp);
966 const auto varOp = symTable.lookup<tosa::VariableOp>(op.getName());
967
968 // Verify prior declaration
969 if (!varOp)
970 return op->emitOpError("'")
971 << op.getName() << "' has not been declared by 'tosa.variable'";
972
973 // Verify type and shape
974 auto variableType = getVariableType(varOp);
975 if (errorIfTypeOrShapeMismatch(op, type, name, variableType,
976 "the input tensor")
977 .failed())
978 return failure();
979 return success();
980}
981
982// Verify that aType and bType have the same element type.
983template <typename T>
984static LogicalResult verifySameElementTypes(T op, Type aType, Type bType,
985 StringRef aName = "input",
986 StringRef bName = "output") {
987 auto aTType = llvm::dyn_cast<TensorType>(aType);
988 auto bTType = llvm::dyn_cast<TensorType>(bType);
989 if (!aTType) {
990 op.emitOpError("expect shaped tensor for ") << aName << ", got " << aType;
991 return failure();
992 }
993 if (!bTType) {
994 op.emitOpError("expect shaped tensor for ") << bName << ", got " << bType;
995 return failure();
996 }
997 auto aElementType = aTType.getElementType();
998 auto bElementType = bTType.getElementType();
999 auto aQuantType =
1000 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(aElementType);
1001 auto bQuantType =
1002 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(bElementType);
1003 if ((aElementType.isIntOrIndexOrFloat() || aQuantType) &&
1004 (bElementType.isIntOrIndexOrFloat() || bQuantType) &&
1005 aElementType != bElementType) {
1006 // only check if both element types are int/index/float/UniformQuantized
1007 // eg, not sure how to check quant::QuantizedType
1008 // this happens in test_conv2d_q_grouped_convolution in
1009 // tfl-to-tosa-pipeline.mlir
1010 op.emitOpError("expect ")
1011 << aName << " and " << bName << " to have same element type, got "
1012 << aElementType << " and " << bElementType;
1013 return failure();
1014 }
1015 return success();
1016}
1017
1018LogicalResult tosa::ArgMaxOp::verify() {
1019 const ShapedType resultType = llvm::cast<ShapedType>(getType());
1020
1021 // Ensure output is of 32-bit integer
1022 if (const auto resultETy = resultType.getElementType();
1023 !resultETy.isIntOrIndex())
1024 return emitOpError("result tensor is not of integer type");
1025
1026 const auto inputType = llvm::cast<ShapedType>(getInput().getType());
1027 if (!inputType.hasRank())
1028 return success();
1029
1030 // Ensure axis is within the tensor rank
1031 const int64_t axis = getAxisAttr().getInt();
1032 if (((axis < 0) || axis >= inputType.getRank()))
1033 return emitOpError("specified axis is outside the rank of the tensor");
1034
1035 if (!resultType.hasRank())
1036 return success();
1037
1038 const ArrayRef<int64_t> inputShape = inputType.getShape();
1039 const ArrayRef<int64_t> outputShape = resultType.getShape();
1040 llvm::SmallVector<int64_t> expectedOutputShape(inputShape);
1041 expectedOutputShape.erase(expectedOutputShape.begin() + axis);
1042 if (failed(verifyCompatibleShape(expectedOutputShape, outputShape)))
1043 return emitOpError("expected output shape '")
1044 << expectedOutputShape << "', got '" << outputShape << "'";
1045
1046 return success();
1047}
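// Example accepted by this verifier (shapes illustrative): with axis = 1 the
// axis dimension is dropped and the result element type must be an integer:
//   %0 = tosa.argmax %arg0 {axis = 1 : i32} : (tensor<2x3x4xf32>) -> tensor<2x4xi32>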
1048
1049template <typename T>
1050static LogicalResult verifyPoolingOp(T op) {
1051 const llvm::ArrayRef<int64_t> kernel = op.getKernel();
1052 if (llvm::any_of(kernel, [](int64_t s) { return s < 1; }))
1053 return op.emitOpError("expect all kernel values to be >= 1, got ")
1054 << kernel;
1055
1056 const llvm::ArrayRef<int64_t> strides = op.getStride();
1057 if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
1058 return op.emitOpError("expect all stride values to be >= 1, got ")
1059 << strides;
1060
1061 const llvm::ArrayRef<int64_t> padding = op.getPad();
1062 if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
1063 return op.emitOpError("expect all padding values to be >= 0, got ")
1064 << padding;
1065
1066 // Padding must be less than kernel size to avoid a divide-by-zero
1067 const int64_t kernelX = kernel[1];
1068 const int64_t padLeft = padding[2];
1069 const int64_t padRight = padding[3];
1070 if (padRight >= kernelX || padLeft >= kernelX)
1071 return op.emitOpError("expected left/right padding to be less than the "
1072 "width of the kernel, got pad_left=")
1073 << padLeft << ", pad_right=" << padRight << ", kernel_x=" << kernelX;
1074
1075 const int64_t kernelY = kernel[0];
1076 const int64_t padTop = padding[0];
1077 const int64_t padBottom = padding[1];
1078 if (padTop >= kernelY || padBottom >= kernelY)
1079 return op.emitOpError("expected top/bottom padding to be less than the "
1080 "height of the kernel, got pad_top=")
1081 << padTop << ", pad_bottom=" << padBottom
1082 << ", kernel_y=" << kernelY;
1083
1084 const auto inputType =
1085 llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
1086 const auto outputType =
1087 llvm::dyn_cast<RankedTensorType>(op.getResult().getType());
1088 if (!inputType || !outputType)
1089 return success();
1090
1091 const auto verifyOutputSize =
1092 [&op](const int64_t inputSize, const int64_t outputSize,
1093 const int64_t kernelSize, const int64_t strideSize,
1094 const int64_t padBefore, const int64_t padAfter,
1095 const llvm::StringRef dimName, const llvm::StringRef dimAxis,
1096 const llvm::StringRef padBeforeName,
1097 const llvm::StringRef padAfterName) -> LogicalResult {
1098 if (ShapedType::isDynamic(inputSize))
1099 return success();
1100
1101 const std::optional<int64_t> calculatedOutSizeMinusOne =
1102 idivCheck(inputSize + padBefore + padAfter - kernelSize, strideSize);
1103 if (!calculatedOutSizeMinusOne.has_value())
1104 return op.emitOpError("expected input_")
1105 << dimName << " + pad_" << padBeforeName << " + pad_"
1106 << padAfterName << " - kernel_" << dimAxis
1107 << " to be wholly divisible by stride_" << dimAxis << ", got ("
1108 << inputSize << " + " << padBefore << " + " << padAfter << " - "
1109 << kernelSize << ") / " << strideSize;
1110
1111 const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
1112 if (ShapedType::isStatic(outputSize) && calculatedOutSize != outputSize)
1113 return op.emitOpError("calculated output ")
1114 << dimName << " did not match expected: "
1115 << "calculated=" << calculatedOutSize
1116 << ", expected=" << outputSize;
1117
1118 return success();
1119 };
1120
1121 if (failed(verifyOutputSize(inputType.getDimSize(1), outputType.getDimSize(1),
1122 kernel[0], strides[0], padding[0], padding[1],
1123 "height", "y", "top", "bottom")))
1124 return failure();
1125
1126 if (failed(verifyOutputSize(inputType.getDimSize(2), outputType.getDimSize(2),
1127 kernel[1], strides[1], padding[2], padding[3],
1128 "width", "x", "left", "right")))
1129 return failure();
1130
1131 return success();
1132}
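// Worked example of the pooling output-size rule above (numbers illustrative):
// with input_height = 33, kernel_y = 3, pad_top = pad_bottom = 1 and
// stride_y = 2,
//   OH = idiv_check(33 + 1 + 1 - 3, 2) + 1 = idiv_check(32, 2) + 1 = 17,
// so a static output height other than 17 is rejected.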
1133
1134LogicalResult tosa::AvgPool2dOp::verify() {
1135 if (failed(verifyPoolingOp(*this)))
1136 return failure();
1137
1138 const Type inputETy = getStorageElementTypeOrSelf(getInput().getType());
1139 const Type resultETy = getStorageElementTypeOrSelf(getOutput().getType());
1140 const Type inputZpETy = getStorageElementTypeOrSelf(getInputZp().getType());
1141 const Type outputZpETy = getStorageElementTypeOrSelf(getOutputZp().getType());
1142
1143 auto accType = getAccType();
1144 if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
1145 return emitOpError("accumulator type for integer tensor is not i32");
1146
1147 if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))
1148 return emitOpError("accumulator type for f16 tensor is not f16/f32");
1149
1150 if (inputETy.isBF16() && !accType.isF32())
1151 return emitOpError("accumulator type for bf16 tensor is not f32");
1152
1153 if (inputETy.isF32() && !accType.isF32())
1154 return emitOpError("accumulator type for f32 tensor is not f32");
1155
1156 if (inputETy != inputZpETy)
1157 return emitOpError("expect both input and its zero point are the same "
1158 "element type, got ")
1159 << inputETy << " and " << inputZpETy;
1160
1161 if (resultETy != outputZpETy)
1162 return emitOpError("expect both output and its zero point are the same "
1163 "element type, got ")
1164 << resultETy << " and " << outputZpETy;
1165
1166 FailureOr<int64_t> maybeIZp = getInputZeroPoint();
1167 if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
1168 return failure();
1169
1170 FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
1171 if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
1172 return failure();
1173
1174 return success();
1175}
1176
1177LogicalResult tosa::ClampOp::verify() {
1178 mlir::Type inputETy =
1179 llvm::cast<ShapedType>(getInput().getType()).getElementType();
1180 if (auto quantType =
1181 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
1182 inputETy = quantType.getStorageType();
1183 }
1184 mlir::Type outputETy =
1185 llvm::cast<ShapedType>(getOutput().getType()).getElementType();
1186 if (auto quantType =
1187 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
1188 outputETy = quantType.getStorageType();
1189 }
1190 if (inputETy != outputETy)
1191 return emitOpError("input/output element types are incompatible.");
1192
1193 auto maxValAttr = getMaxValAttr();
1194 auto minValAttr = getMinValAttr();
1195
1196 unsigned dataTypeBitWidth = inputETy.getIntOrFloatBitWidth();
1197
1198 if (inputETy.isInteger(dataTypeBitWidth)) {
1199 // if input datatype is integer, check that the min_val/max_val attributes
1200 // are integer attributes, and that their type is the same as the input's
1201 // datatype
1202 auto intMaxValAttr = mlir::dyn_cast<mlir::IntegerAttr>(maxValAttr);
1203 auto intMinValAttr = mlir::dyn_cast<mlir::IntegerAttr>(minValAttr);
1204 if (!intMaxValAttr || !intMinValAttr ||
1205 (intMaxValAttr.getType() != intMinValAttr.getType()) ||
1206 (intMaxValAttr.getType() != inputETy))
1207 return emitOpError("min/max attributes types are incompatible with "
1208 "input/output element types.");
1209
1210 const bool isUnsigned = inputETy.isUnsignedInteger();
1211 const bool isBoolean = inputETy.isInteger(1);
1212 const APInt minVal = intMinValAttr.getValue();
1213 const APInt maxVal = intMaxValAttr.getValue();
1214 if ((isUnsigned || isBoolean) ? maxVal.ult(minVal) : maxVal.slt(minVal))
1215 return emitOpError("expected min_val <= max_val, got min_val=")
1216 << minValAttr << ", max_val=" << maxValAttr;
1217 } else {
1218 // otherwise, input datatype is float, check that the min_val/max_val
1219 // attributes share the same type and that their type is the same as the
1220 // input's datatype
1221 auto floatMaxValAttr = mlir::dyn_cast<mlir::FloatAttr>(maxValAttr);
1222 auto floatMinValAttr = mlir::dyn_cast<mlir::FloatAttr>(minValAttr);
1223 if (!floatMaxValAttr || !floatMinValAttr ||
1224 (floatMaxValAttr.getType() != floatMinValAttr.getType()) ||
1225 (floatMaxValAttr.getType() != inputETy))
1226 return emitOpError("min/max attributes types are incompatible with "
1227 "input/output element types.");
1228
1229 const APFloat minVal = floatMinValAttr.getValue();
1230 const APFloat maxVal = floatMaxValAttr.getValue();
1231 if (minVal.isNaN() || maxVal.isNaN())
1232 return emitOpError("min/max attributes should not be 'NaN', got min_val=")
1233 << minValAttr << ", max_val=" << maxValAttr;
1234
1235 if (maxVal < minVal)
1236 return emitOpError("expected min_val <= max_val, got min_val=")
1237 << minValAttr << ", max_val=" << maxValAttr;
1238 }
1239
1240 return success();
1241}
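// A clamp that satisfies these checks (shapes and values illustrative):
// min_val/max_val match the input element type, min_val <= max_val, and
// neither bound is NaN:
//   %0 = tosa.clamp %arg0 {min_val = 0.0 : f32, max_val = 6.0 : f32} : (tensor<1x16xf32>) -> tensor<1x16xf32>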
1242
1243//===----------------------------------------------------------------------===//
1244// TOSA Operator Quantization Builders.
1245//===----------------------------------------------------------------------===//
1246
1247/// This builder is called on all convolution operators except TransposeConv,
1248/// which has specialized output shape semantics. The builder also defines the
1249/// bitwidth of the output given the bit width of the input & weight content.
1250static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
1251 Type outputType, Value input, Value weight,
1252 Value bias, DenseI64ArrayAttr pad,
1253 DenseI64ArrayAttr stride,
1254 DenseI64ArrayAttr dilation,
1255 TypeAttr accType) {
1256 auto zps = createZPsAsConst(builder, input, weight);
1257 result.addOperands({input, weight, bias, zps.first, zps.second});
1258 result.addAttribute("pad", pad);
1259 result.addAttribute("stride", stride);
1260 result.addAttribute("dilation", dilation);
1261 result.addAttribute("acc_type", accType);
1262 Type finalOutputType = outputType;
1263 auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
1264 if (quantAttr) {
1265 finalOutputType =
1266 buildConvOpResultTypeInfo(builder, outputType, input, weight);
1267 }
1268 result.addTypes(finalOutputType);
1269}
1270
1271/// Handles tosa.transpose_conv2d which has outpad and output shape
1272/// attributes.
1273static void
1274buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
1275 Type outputType, Value input, Value weight,
1276 Value bias, DenseI64ArrayAttr outpad,
1277 DenseI64ArrayAttr stride, TypeAttr accType) {
1278 auto zps = createZPsAsConst(builder, input, weight);
1279 result.addOperands({input, weight, bias, zps.first, zps.second});
1280 result.addAttribute("out_pad", outpad);
1281 result.addAttribute("stride", stride);
1282 result.addAttribute("acc_type", accType);
1283 Type finalOutputType = outputType;
1284 auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
1285 if (quantAttr) {
1286 finalOutputType =
1287 buildConvOpResultTypeInfo(builder, outputType, input, weight);
1288 }
1289 result.addTypes(finalOutputType);
1290}
1291
1292/// The tosa.matmul op is also intended to be generated where a fully_connected
1293/// op would be required but the weight is not a constant. In this case,
1294/// the fully_connected op must be expressed using matmul.
1295/// TODO: Add link to the legalization document explaining this.
1296static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
1297 OperationState &result, Type outputType,
1298 Value a, Value b) {
1299 auto zps = createZPsAsConst(builder, a, b);
1300 result.addOperands({a, b, zps.first, zps.second});
1301
1302 Type finalOutputType{outputType};
1303 if (auto quantAttr = buildMatMulOpQuantizationAttr(builder, a, b)) {
1304 auto eType = getStorageElementTypeOrSelf(a.getType());
1305 auto inputBits = eType.getIntOrFloatBitWidth();
1306
1307 auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
1308 assert(outputShapedType && "Output must be a shaped type");
1309
1310 IntegerType accElementType;
1311 if (inputBits == 16)
1312 accElementType = builder.getIntegerType(48);
1313 else
1314 accElementType = builder.getI32Type();
1315
1316 finalOutputType = outputShapedType.clone(accElementType);
1317 }
1318 result.addTypes(finalOutputType);
1319}
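// Accumulator selection sketch (assumes quantized operands so that
// buildMatMulOpQuantizationAttr succeeds): i8 x i8 operands yield an i32
// result element type, while i16 x i16 operands yield i48, following the
// operand storage bit width checked above.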
1320
1321/// Both the tosa.avg_pool2d and unary ops use the same
1322/// UnaryOpQuantizationAttr, but the avg_pool2d operator has its own builder
1323/// as it has additional parameters not part of the unary ops.
1324static void
1325buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
1326 Type outputType, Value input,
1327 DenseArrayAttr kernel, DenseArrayAttr stride,
1328 DenseArrayAttr pad, TypeAttr accType) {
1329 const Location loc{result.location};
1330 int64_t inputZp{0};
1331 int64_t outputZp{0};
1332
1333 if (auto quantAttr =
1334 buildUnaryOpQuantizationAttr(builder, input, outputType)) {
1335 inputZp = quantAttr.getInputZp();
1336 outputZp = quantAttr.getOutputZp();
1337 }
1338 const std::optional<Value> inputZpOp =
1339 createZeroPointTensor(builder, loc, input.getType(), inputZp);
1340 if (!inputZpOp) {
1341 (void)emitError(
1342 loc,
1343 "Failed to create input zero point tensor for quantized AVG_POOL2D op");
1344 }
1345 const std::optional<Value> outputZpOp =
1346 createZeroPointTensor(builder, loc, outputType, outputZp);
1347 if (!outputZpOp) {
1348 (void)emitError(loc, "Failed to create output zero point tensor for "
1349 "quantized AVG_POOL2D op");
1350 }
1351
1352 if (inputZpOp && outputZpOp) {
1353 result.addOperands({input, inputZpOp.value(), outputZpOp.value()});
1354 } else {
1355 // failed to create one or more zero points above: just add input as
1356 // operands; this will trigger an error in building the op because of the
1357 // missing zero points
1358 result.addOperands({input});
1359 }
1360 result.addAttribute("kernel", kernel);
1361 result.addAttribute("stride", stride);
1362 result.addAttribute("pad", pad);
1363 result.addAttribute("acc_type", accType);
1364 result.types.push_back(outputType);
1365}
1366
1367/// This builder is called on the single-parameter negate operator
1368/// to construct input and output zero points based on their
1369/// types.
1370static void buildNegateOpWithQuantInfo(OpBuilder &builder,
1371 OperationState &result, Type outputType,
1372 Value input) {
1373 const Location loc{result.location};
1374 int64_t input1Zp{0};
1375 int64_t outputZp{0};
1376 auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
1377 if (quantAttr) {
1378 input1Zp = quantAttr.getInputZp();
1379 outputZp = quantAttr.getOutputZp();
1380 }
1381 const std::optional<Value> input1ZpOp =
1382 createZeroPointTensor(builder, loc, input.getType(), input1Zp);
1383 if (!input1ZpOp) {
1384 (void)emitError(
1385 loc, "Failed to create input1 zero point for quantized NEGATE op");
1386 }
1387
1388 const std::optional<Value> outputZpOp =
1389 createZeroPointTensor(builder, loc, input.getType(), outputZp);
1390 if (!outputZpOp) {
1391 (void)emitError(
1392 loc, "Failed to create output zero point for quantized NEGATE op");
1393 }
1394
1395 if (input1ZpOp && outputZpOp) {
1396 result.addOperands({input, input1ZpOp.value(), outputZpOp.value()});
1397 } else {
1398 // failed to create one or more zero points above: just add input as
1399 // operands. This will trigger error in building the op because of
1400 // missing zero points
1401 result.addOperands({input});
1402 }
1403
1404 result.types.push_back(outputType);
1405}
1406
1407/// This builder is called on the TOSA pad operator that needs to create its
1408/// own OptionalAttr quantization_attr parameter to scale the padding values
1409/// correctly. The absence of a pad_const operand is interpreted as zero padding.
1410static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
1411 Type outputType, Value input,
1412 Value paddings) {
1413 const Location loc{result.location};
1414 int32_t zp{0};
1415 const auto quantAttr = buildPadOpQuantizationAttr(builder, input);
1416 if (quantAttr) {
1417 zp = static_cast<int32_t>(quantAttr.getInputZp());
1418 }
1419 const auto padConstOp{createPadConstTensor(builder, loc, input, zp)};
1420 result.addOperands({input, paddings, padConstOp});
1421 result.types.push_back(outputType);
1422}
1423
1424static void buildVariableOp(OpBuilder &builder, OperationState &result,
1425 StringRef name, Type variableType,
1426 Attribute initialValue) {
1427 const Location loc{result.location};
1428 auto nameAttr = builder.getStringAttr(name);
1429
1430 auto shapedType = dyn_cast<ShapedType>(variableType);
1431 if (!shapedType) {
1432 (void)emitError(loc, "variable type must be a shaped type");
1433 return;
1434 }
1435 if (!shapedType.hasRank()) {
1436 (void)emitError(loc, "variable type must be a ranked type");
1437 return;
1438 }
1439
1440 auto elementType = shapedType.getElementType();
1441 auto elementTypeAttr = TypeAttr::get(elementType);
1442 ArrayRef<int64_t> shape = shapedType.getShape();
1443 auto varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
1444
1445 result.addAttribute("sym_name", nameAttr);
1446 result.addAttribute("var_shape", varShapeAttr);
1447 result.addAttribute("type", elementTypeAttr);
1448 result.addAttribute("initial_value", initialValue);
1449}
1450
1451//===----------------------------------------------------------------------===//
1452// TOSA Operator Return Type Inference.
1453//===----------------------------------------------------------------------===//
1454
1455static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
1456 SmallVector<int64_t> &outShape) {
1457 int64_t outRank = 0;
1458 for (int i = 0, e = operands.size(); i != e; ++i) {
1459 auto shape = operands.getShape(i);
1460 if (!shape.hasRank()) {
1461 // TODO(jennik): Update function to have better case handling for
1462 // invalid operands and for ranked tensors.
1463 return failure();
1464 }
1465 outRank = std::max<int64_t>(outRank, shape.getRank());
1466 }
1467
1468 outShape.resize(outRank, 1);
1469
1470 for (int i = 0, e = operands.size(); i != e; ++i) {
1471 auto shape = operands.getShape(i);
1472 auto rankDiff = outShape.size() - shape.getRank();
1473
1474 for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
1475 auto dim1 = outShape[i + rankDiff];
1476 auto dim2 = shape.getDimSize(i);
1477 auto resolvedDim = dim1;
1478
1479 if (dim1 == 1) {
1480 resolvedDim = dim2;
1481 } else if (dim2 == 1) {
1482 resolvedDim = dim1;
1483 } else if (dim1 != dim2) {
1484 return failure();
1485 }
1486 outShape[i + rankDiff] = resolvedDim;
1487 }
1488 }
1489
1490 return success();
1491}
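// For example, operand shapes [2, 1, 3] and [4, 1] broadcast to [2, 4, 3]:
// the shorter shape is right-aligned and every size-1 dimension stretches to
// the other operand's size, while mismatched non-1 sizes such as [2, 3] vs
// [4, 3] make the resolution fail.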
1492
1493LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
1494 MLIRContext *context, ::std::optional<Location> location,
1495 ArgMaxOp::Adaptor adaptor,
1496 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1497 ShapeAdaptor inputShape(adaptor.getInput().getType());
1498 IntegerAttr axis = adaptor.getProperties().axis;
1499 int32_t axisVal = axis.getValue().getSExtValue();
1500
1501 if (!inputShape.hasRank()) {
1502 inferredReturnShapes.push_back(ShapedTypeComponents());
1503 return success();
1504 }
1505
1506 SmallVector<int64_t> outShape;
1507 outShape.reserve(inputShape.getRank() - 1);
1508 for (int i = 0, s = inputShape.getRank(); i < s; i++) {
1509 if (i == axisVal)
1510 continue;
1511 outShape.push_back(inputShape.getDimSize(i));
1512 }
1513
1514 inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1515 return success();
1516}
1517
1518LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
1519 MLIRContext *context, ::std::optional<Location> location,
1520 RFFT2dOp::Adaptor adaptor,
1521 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1522 ShapeAdaptor inputShape(adaptor.getInputReal().getType());
1523
1524 if (!inputShape.hasRank())
1525 return failure();
1526
1527 llvm::SmallVector<int64_t> outputShape;
1528 outputShape.resize(3, ShapedType::kDynamic);
1529 outputShape[0] = inputShape.getDimSize(0);
1530 outputShape[1] = inputShape.getDimSize(1);
1531 int64_t inWidth = inputShape.getDimSize(2);
1532
1533 // Note that we can support this calculation symbolically
1534 // in the future e.g. [x, y, z] -> [x, y, z / 2 + 1]
1535 if (inWidth != ShapedType::kDynamic)
1536 outputShape[2] = inWidth / 2 + 1;
1537
1538 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1539 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1540
1541 return success();
1542}
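// For example, a real input of shape [8, 16, 32] infers two result shapes
// (real and imaginary parts) of [8, 16, 17], since the innermost dimension
// becomes W / 2 + 1; a dynamic input width leaves that dimension dynamic.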
1543
1544static LogicalResult verifyDimIsPowerOfTwo(Operation *op, const int64_t dimSize,
1545 const llvm::StringRef dimName) {
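// For a positive dimSize, dimSize & (dimSize - 1) clears the lowest set bit,
// so the result is zero exactly when a single bit is set, i.e. when dimSize
// is a power of two.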
1546 const bool isPowerOfTwo = (dimSize & (dimSize - 1)) == 0 && dimSize > 0;
1547 if (!isPowerOfTwo)
1548 return op->emitOpError("expected ")
1549 << dimName << " to be a power of two, got " << dimSize;
1550
1551 return success();
1552}
1553
1554LogicalResult tosa::RFFT2dOp::verify() {
1555 const auto outputTypes = getResultTypes();
1556 if (failed(verifyCompatibleShapes(outputTypes)))
1557 return emitOpError("expected output shapes to match, got ") << outputTypes;
1558
1559 const auto inputType =
1560 llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
1561 if (!inputType)
1562 return success();
1563
1564 const int64_t height = inputType.getDimSize(1);
1565 if (ShapedType::isStatic(height) &&
1566 failed(verifyDimIsPowerOfTwo(*this, height, "height")))
1567 return failure();
1568
1569 const int64_t width = inputType.getDimSize(2);
1570 if (ShapedType::isStatic(width) &&
1571 failed(verifyDimIsPowerOfTwo(*this, width, "width")))
1572 return failure();
1573
1574 const auto outputType = llvm::dyn_cast<RankedTensorType>(outputTypes[0]);
1575 if (!outputType)
1576 return success();
1577
1578 // Batch and height input/output dimensions should match
1579 if (failed(verifyCompatibleShape(inputType.getShape().drop_back(),
1580 outputType.getShape().drop_back())))
1581 return emitOpError("expected batch and height dimensions of input/output "
1582 "to match, got input=")
1583 << inputType << " output=" << outputType;
1584
1585 // Output width dimension expected to be input_width / 2 + 1
1586 const int64_t outputWidth = outputType.getDimSize(2);
1587 if (ShapedType::isStatic(width) && ShapedType::isStatic(outputWidth) &&
1588 (outputWidth != (width / 2) + 1))
1589 return emitOpError(
1590 "expected output width to be equal to input_width / 2 + 1, got ")
1591 << outputWidth;
1592
1593 return success();
1594}
1595
1596LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
1597 MLIRContext *context, ::std::optional<Location> location,
1598 FFT2dOp::Adaptor adaptor,
1599 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1600 inferredReturnShapes.push_back(
1601 ShapedTypeComponents(ShapeAdaptor(adaptor.getInputReal().getType())));
1602 inferredReturnShapes.push_back(
1603 ShapedTypeComponents(ShapeAdaptor(adaptor.getInputImag().getType())));
1604 return success();
1605}
1606
1607LogicalResult tosa::FFT2dOp::verify() {
1608 const auto inputRealType =
1609 llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
1610 const auto inputImagType =
1611 llvm::dyn_cast<RankedTensorType>(getInputImag().getType());
1612 if (!inputRealType || !inputImagType)
1613 return success();
1614
1615 const auto trySelectStaticDim = [](const int64_t a, const int64_t b) {
1616 return ShapedType::isDynamic(a) ? a : b;
1617 };
1618
1619 const int64_t height = trySelectStaticDim(inputRealType.getDimSize(1),
1620 inputImagType.getDimSize(1));
1621 if (ShapedType::isStatic(height) &&
1622 failed(verifyDimIsPowerOfTwo(*this, height, "height")))
1623 return failure();
1624
1625 const int64_t width = trySelectStaticDim(inputRealType.getDimSize(2),
1626 inputImagType.getDimSize(2));
1627 if (ShapedType::isStatic(width) &&
1628 failed(verifyDimIsPowerOfTwo(*this, width, "width")))
1629 return failure();
1630
1631 return success();
1632}
1633
1634LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
1635 MLIRContext *context, ::std::optional<Location> location,
1636 ConcatOp::Adaptor adaptor,
1637 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1638 // Infer all dimension sizes by reducing based on inputs.
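// e.g. concatenating shapes [2, 3] and [2, 5] along axis = 1 infers an
// output shape of [2, 8]; all non-axis dimensions must agree.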
1639 const Properties &prop = adaptor.getProperties();
1640 int32_t axis = prop.axis.getValue().getSExtValue();
1641 llvm::SmallVector<int64_t> outputShape;
1642 bool hasRankedInput = false;
1643 for (auto operand : adaptor.getOperands()) {
1644 ShapeAdaptor operandShape(operand.getType());
1645 if (!operandShape.hasRank())
1646 continue;
1647
1648 // Copy the Operand's rank.
1649 if (!hasRankedInput)
1650 outputShape.resize(operandShape.getRank(), ShapedType::kDynamic);
1651
1652 // For each non-axis dim, copy the first static size seen; later operands must match it.
1653 for (int i = 0, s = operandShape.getRank(); i < s; i++) {
1654 if (i == axis || operandShape.isDynamicDim(i))
1655 continue;
1656 if (outputShape[i] == ShapedType::kDynamic)
1657 outputShape[i] = operandShape.getDimSize(i);
1658 if (outputShape[i] != operandShape.getDimSize(i))
1659 return emitOptionalError(location,
1660 "Cannot concat tensors with different sizes"
1661 " on the non-axis dimension ",
1662 i);
1663 }
1664
1665 hasRankedInput = true;
1666 }
1667
1668 if (adaptor.getInput1().empty())
1669 return failure();
1670
1671 Type inputType =
1672 llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType();
1673 if (!hasRankedInput) {
1674 inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
1675 return success();
1676 }
1677
1678 // Determine the dimension size along the concatenation axis.
1679 int64_t concatDimSize = 0;
1680 for (auto operand : adaptor.getOperands()) {
1681 ShapeAdaptor operandShape(operand.getType());
1682
1683 // We need to know the length of the concatenation axis of all inputs to
1684 // determine the dimension size of the output shape.
1685 if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
1686 concatDimSize = ShapedType::kDynamic;
1687 break;
1688 }
1689
1690 concatDimSize += operandShape.getDimSize(axis);
1691 }
1692
1693 outputShape[axis] = concatDimSize;
1694
1695 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
1696 return success();
1697}
1698
1699LogicalResult tosa::ConcatOp::verify() {
1700 // Check that each input has the same element type as the output.
1701 auto outType = getOutput().getType();
1702 const Operation::operand_range inputList = getInput1();
1703
1704 // Check there is at least one input
1705 if (inputList.empty())
1706 return emitOpError("expect at least one input");
1707
1708 if (!llvm::all_of(inputList, [&](auto input) {
1709 return succeeded(verifySameElementTypes(
1710 *this, /* inType = */ input.getType(), outType));
1711 })) {
1712 return failure();
1713 }
1714
1715 const int32_t axis = getAxis();
1716 ShapeAdaptor firstRankedInputShape = nullptr;
1717 for (const auto &input : inputList) {
1718 const Type inputType = input.getType();
1719 ShapeAdaptor currShape(inputType);
1720 if (currShape.hasRank()) {
1721 firstRankedInputShape = currShape;
1722 // Check axis is in expected range
1723 if (axis < 0 || axis >= firstRankedInputShape.getRank())
1724 return emitOpError("expect axis to be within range 0 <= axis < "
1725 "rank(input1[firstRankedTensorIdx]), got ")
1726 << axis;
1727 break;
1728 }
1729 }
1730
1731 const auto allOperandsHasRank = [](const Value input) {
1732 return ShapeAdaptor(input.getType()).hasRank();
1733 };
1734 if (llvm::all_of(inputList, allOperandsHasRank)) {
1735 const int64_t firstInputRank = firstRankedInputShape.getRank();
1736
1737 for (const auto &[index, input] : llvm::enumerate(inputList.drop_front())) {
1738 const ShapeAdaptor inputShape(input.getType());
1739 const int64_t inputRank = inputShape.getRank();
1740 const size_t operandNum = index + 1;
1741
1742 // Check that each operand has the same rank
1743 if (inputRank != firstInputRank)
1744 return emitOpError(
1745 "expect all operands to have the same rank, but got ")
1746 << firstInputRank << " vs " << inputRank << " on operands 0 and "
1747 << operandNum;
1748
1749 // Check non-axis dims match
1750 for (int i = 0; i < inputRank; i++) {
1751 const int64_t inputDim = inputShape.getDimSize(i);
1752 const int64_t firstInputDim = firstRankedInputShape.getDimSize(i);
1753 if (i == axis || firstRankedInputShape.isDynamicDim(i) ||
1754 inputShape.isDynamicDim(i))
1755 continue;
1756 if (inputDim != firstInputDim)
1757 return emitOpError("expect all operand shapes to have the same sizes "
1758 "on non-axis dimensions, but got ")
1759 << inputDim << " vs " << firstInputDim << " at index " << i
1760 << " on operands 0 and " << operandNum;
1761 }
1762 }
1763
1764 // ERROR_IF(axis_sum != shape[axis]);
1765 int64_t axisSum = 0;
1766 for (const auto &input : inputList) {
1767 const ShapeAdaptor inputShape(input.getType());
1768 if (inputShape.isDynamicDim(axis)) {
1769 // make axisSum negative to indicate invalid value
1770 axisSum = -1;
1771 break;
1772 }
1773 axisSum += inputShape.getDimSize(axis);
1774 }
1775 const ShapeAdaptor outputShape(outType);
1776 if (axisSum >= 0 && outputShape.hasRank() &&
1777 !outputShape.isDynamicDim(axis) &&
1778 axisSum != outputShape.getDimSize(axis))
1779 return emitOpError("requires sum of axis dimensions of input1 "
1780 "equal to output axis dimension, got ")
1781 << axisSum << " and " << outputShape.getDimSize(axis);
1782 }
1783
1784 return success();
1785}
1786
1787LogicalResult tosa::EqualOp::inferReturnTypeComponents(
1788 MLIRContext *context, ::std::optional<Location> location,
1789 ValueShapeRange operands, DictionaryAttr attributes,
1790 OpaqueProperties properties, RegionRange regions,
1791 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1792 auto elementType = IntegerType::get(context, /*width=*/1);
1793
1793
1794 llvm::SmallVector<int64_t> outShape;
1795 if (resolveBroadcastShape(operands, outShape).failed()) {
1796 inferredReturnShapes.push_back(ShapedTypeComponents(elementType));
1797 return success();
1798 }
1799
1800 inferredReturnShapes.push_back(ShapedTypeComponents(outShape, elementType));
1801 return success();
1802}
1803
1804bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1805 if (l.size() != r.size() || l.size() != 1)
1806 return false;
1807 return succeeded(verifyCompatibleShape(l[0], r[0]));
1808}
1809
1810LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
1811 MLIRContext *context, ::std::optional<Location> location,
1812 MatMulOp::Adaptor adaptor,
1813 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
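// tosa.matmul multiplies [N, H, C] by [N, C, W] to produce [N, H, W]; each
// output dimension is taken from whichever operand provides a ranked shape.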
1814 ShapeAdaptor lhsShape(adaptor.getA().getType());
1815 ShapeAdaptor rhsShape(adaptor.getB().getType());
1816
1817 // All shapes are dynamic.
1818 SmallVector<int64_t> outShape;
1819 outShape.resize(3, ShapedType::kDynamic);
1820
1821 if (lhsShape.hasRank()) {
1822 outShape[0] = lhsShape.getDimSize(0);
1823 outShape[1] = lhsShape.getDimSize(1);
1824 }
1825
1826 if (rhsShape.hasRank()) {
1827 outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(0)
1828 : outShape[0];
1829 outShape[2] = rhsShape.getDimSize(2);
1830 }
1831
1832 inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1833 return success();
1834}
1835
1836LogicalResult MatMulOp::verify() {
1837 auto aType = llvm::dyn_cast<ShapedType>(getA().getType());
1838 auto bType = llvm::dyn_cast<ShapedType>(getB().getType());
1839
1840 // Must be shaped tensor types
1841 if (!aType)
1842 return emitOpError("expect a shaped tensor for input a, got ")
1843 << getA().getType();
1844
1845 if (!bType)
1846 return emitOpError("expect a shaped tensor for input b, got ")
1847 << getB().getType();
1848
1849 auto aElementType = aType.getElementType();
1850 auto bElementType = bType.getElementType();
1851
1852 auto aQuantizedEType =
1853 llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
1854 auto bQuantizedEType =
1855 llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);
1856
1857 if (aQuantizedEType || bQuantizedEType) {
1858 if (!aQuantizedEType || !bQuantizedEType) {
1859 return emitOpError("expect operands to be both quantized or both not "
1860 "quantized, got ")
1861 << aElementType << " and " << bElementType;
1862 }
1863 // both a and b have quantized element types
1864 auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth();
1865 auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth();
1866 if (aQuantWidth != bQuantWidth) {
1867 return emitOpError("expect quantized operands to have same widths, got ")
1868 << aQuantWidth << " and " << bQuantWidth;
1869 }
1870 }
1871
1872 // check a_zp and b_zp
1873 auto aEType = getStorageElementTypeOrSelf(aType);
1874 auto aZpEType = getStorageElementTypeOrSelf(getAZp().getType());
1875 if (aEType != aZpEType) {
1876 return emitOpError("expect input a and a_zp have the same "
1877 "element type, got ")
1878 << aEType << " and " << aZpEType;
1879 }
1880
1881 auto bEType = getStorageElementTypeOrSelf(bType);
1882 auto bZpEType = getStorageElementTypeOrSelf(getBZp().getType());
1883 if (bEType != bZpEType) {
1884 return emitOpError("expect input b and b_zp have the same "
1885 "element type, got ")
1886 << bEType << " and " << bZpEType;
1887 }
1888
1889 FailureOr<int64_t> maybeAZp = getAZeroPoint();
1890 if (succeeded(maybeAZp) && verifyAZeroPoint(*maybeAZp).failed())
1891 return failure();
1892
1893 FailureOr<int64_t> maybeBZp = getBZeroPoint();
1894 if (succeeded(maybeBZp) && verifyBZeroPoint(*maybeBZp).failed())
1895 return failure();
1896
1897 return success();
1898}
1899
1900LogicalResult tosa::MatmulTBlockScaledOp::inferReturnTypeComponents(
1901 MLIRContext *context, ::std::optional<Location> location,
1902 MatmulTBlockScaledOp::Adaptor adaptor,
1903 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1904 SmallVector<int64_t, 3> outShape(3, ShapedType::kDynamic);
1905
1906 const auto aDataShape = cast<ShapedType>(adaptor.getAData().getType());
1907 if (aDataShape.hasRank()) {
1908 outShape[0] = aDataShape.getDimSize(0);
1909 outShape[1] = aDataShape.getDimSize(1);
1910 }
1911
1912 const auto aScaleShape = cast<ShapedType>(adaptor.getAScale().getType());
1913 if (aScaleShape.hasRank()) {
1914 outShape[0] = ShapedType::isDynamic(outShape[0]) ? aScaleShape.getDimSize(0)
1915 : outShape[0];
1916 outShape[1] = ShapedType::isDynamic(outShape[1]) ? aScaleShape.getDimSize(1)
1917 : outShape[1];
1918 }
1919
1920 // If B batch size is 1, it is broadcast across A's batch size
1921 const auto bDataShape = cast<ShapedType>(adaptor.getBData().getType());
1922 if (bDataShape.hasRank()) {
1923 const int64_t bDataBatchSize = bDataShape.getDimSize(0);
1924 if (bDataBatchSize != 1)
1925 outShape[0] =
1926 ShapedType::isDynamic(outShape[0]) ? bDataBatchSize : outShape[0];
1927 outShape[2] = bDataShape.getDimSize(1);
1928 }
1929
1930 const auto bScaleShape = cast<ShapedType>(adaptor.getBScale().getType());
1931 if (bScaleShape.hasRank()) {
1932 const int64_t bScaleBatchSize = bScaleShape.getDimSize(0);
1933 if (bScaleBatchSize != 1)
1934 outShape[0] =
1935 ShapedType::isDynamic(outShape[0]) ? bScaleBatchSize : outShape[0];
1936 outShape[2] = ShapedType::isDynamic(outShape[2]) ? bScaleShape.getDimSize(1)
1937 : outShape[2];
1938 }
1939
1940 inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1941 return success();
1942}
1943
1944LogicalResult MatmulTBlockScaledOp::verify() {
1945 // Verify same input data types
1946 const Type aDataType = getAData().getType();
1947 const Type bDataType = getBData().getType();
1948 if (failed(verifySameElementTypes(*this, aDataType, bDataType, "A_data",
1949 "B_data")))
1950 return failure();
1951
1952 auto tryUpdateDimOrFailure = [&](int64_t &currDim, const int64_t newDim,
1953 const StringRef operandName,
1954 const StringRef dimName) -> LogicalResult {
1955 if (ShapedType::isDynamic(currDim)) {
1956 currDim = newDim;
1957 return success();
1958 } else if (ShapedType::isStatic(newDim) && currDim != newDim) {
1959 return emitOpError("expected ")
1960 << dimName << " of " << operandName << " to match size " << currDim
1961 << ", got " << newDim;
1962 }
1963 return success();
1964 };
1965
1966 // Verify input shape compatibility
1967 int64_t N = ShapedType::kDynamic;
1968 int64_t D = ShapedType::kDynamic;
1969 int64_t H = ShapedType::kDynamic;
1970 int64_t W = ShapedType::kDynamic;
1971 int64_t C = ShapedType::kDynamic;
1972 int64_t multiplesOfC = ShapedType::kDynamic;
1973
1974 const ShapeAdaptor aDataShape = ShapeAdaptor(aDataType);
1975 if (aDataShape.hasRank()) {
1976 N = aDataShape.getDimSize(0);
1977 H = aDataShape.getDimSize(1);
1978 C = aDataShape.getDimSize(2);
1979 }
1980
1981 const ShapeAdaptor aScaleShape = ShapeAdaptor(getAScale().getType());
1982 if (aScaleShape.hasRank()) {
1983 if (failed(tryUpdateDimOrFailure(N, aScaleShape.getDimSize(0), "a_scale",
1984 "batch")) ||
1985 failed(tryUpdateDimOrFailure(H, aScaleShape.getDimSize(1), "a_scale",
1986 "height")))
1987 return failure();
1988 multiplesOfC = aScaleShape.getDimSize(2);
1989 }
1990
1991 const ShapeAdaptor bDataShape = ShapeAdaptor(bDataType);
1992 if (bDataShape.hasRank()) {
1993 if (failed(tryUpdateDimOrFailure(D, bDataShape.getDimSize(0), "b_data",
1994 "batch")) ||
1995 failed(tryUpdateDimOrFailure(C, bDataShape.getDimSize(2), "b_data",
1996 "channels")))
1997 return failure();
1998 W = bDataShape.getDimSize(1);
1999 }
2000
2001 const ShapeAdaptor bScaleShape = ShapeAdaptor(getBScale().getType());
2002 if (bScaleShape.hasRank()) {
2003 if (failed(tryUpdateDimOrFailure(D, bScaleShape.getDimSize(0), "b_scale",
2004 "batch")) ||
2005 failed(tryUpdateDimOrFailure(W, bScaleShape.getDimSize(1), "b_scale",
2006 "width")) ||
2007 failed(tryUpdateDimOrFailure(multiplesOfC, bScaleShape.getDimSize(2),
2008 "b_scale", "C/block_size")))
2009 return failure();
2010 }
2011
2012 // Verify batch size is broadcast compatible
2013 if (ShapedType::isStatic(N) && ShapedType::isStatic(D) && N != D && D != 1)
2014 return emitOpError("expect B matrix batch size to be broadcast compatible "
2015 "with A, got D=")
2016 << D << " vs N=" << N;
2017
2018 // Verify C is a multiple of block size
2019 const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
2020 if (ShapedType::isStatic(C) && C % blockSize != 0)
2021 return emitOpError("expect C to be a multiple of block size, got C=")
2022 << C << ", block_size=" << blockSize;
2023
2024 // Verify multiplesOfC is C / block size
2025 if (ShapedType::isStatic(C) && ShapedType::isStatic(multiplesOfC) &&
2026 multiplesOfC != C / blockSize)
2027 return emitOpError(
2028 "expect scale operands dimension 2 to equal C/block_size (")
2029 << C << "/" << blockSize << ")"
2030 << ", got " << multiplesOfC;
2031
2032 // Verify output shape
2033 N = ShapedType::isDynamic(N) ? D : N;
2034 const SmallVector<int64_t, 3> expectedOutputShape = {N, H, W};
2035 const auto outputType = cast<ShapedType>(getResult().getType());
2036 if (outputType.hasRank() &&
2037 failed(
2038 verifyCompatibleShape(outputType.getShape(), expectedOutputShape))) {
2039 InFlightDiagnostic opError = emitOpError("expected output shape ");
2040 auto stringifyDim = [&](int64_t d) {
2041 if (ShapedType::isDynamic(d))
2042 opError << "?";
2043 else
2044 opError << d;
2045 };
2046 llvm::interleaveComma(outputType.getShape(), opError, stringifyDim);
2047 opError << " to be compatible with expected output shape ";
2048 llvm::interleaveComma(expectedOutputShape, opError, stringifyDim);
2049 return opError;
2050 }
2051
2052 return success();
2053}
2054
2055LogicalResult tosa::PadOp::inferReturnTypeComponents(
2056 MLIRContext *context, ::std::optional<Location> location,
2057 PadOp::Adaptor adaptor,
2058 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2059 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2060 auto paddingRank =
2061 cast<tosa::shapeType>(adaptor.getPadding().getType()).getRank();
2062 SmallVector<int64_t> outputShape;
2063
2064 // If the input rank is unknown, we can infer the output rank using the
2065 // padding shape's rank divided by 2.
2066 if (!inputShape.hasRank()) {
2067 outputShape.resize(paddingRank / 2, ShapedType::kDynamic);
2068 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2069 return success();
2070 }
2071
2072 SmallVector<int64_t> paddingValues;
2073 // If the padding values are not constant, all output dimensions are dynamic.
2074 if (!tosa::getConstShapeValues(adaptor.getPadding().getDefiningOp(),
2075 paddingValues)) {
2076 outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
2077 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2078 return success();
2079 }
2080
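// Each output dimension is the input dimension plus the leading and trailing
// padding, e.g. an input of shape [3, 4] with padding [1, 1, 2, 2] infers an
// output shape of [5, 8].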
2081 outputShape.reserve(inputShape.getRank());
2082 for (int i = 0, s = inputShape.getRank(); i < s; i++) {
2083 if (inputShape.isDynamicDim(i)) {
2084 outputShape.push_back(ShapedType::kDynamic);
2085 continue;
2086 }
2087 auto padFront = paddingValues[i * 2];
2088 auto padBack = paddingValues[i * 2 + 1];
2089 if (padFront < 0 || padBack < 0) {
2090 // A negative (-1, i.e. dynamic) padding value makes the output dim unknown.
2091 outputShape.push_back(ShapedType::kDynamic);
2092 continue;
2093 }
2094
2095 outputShape.push_back(inputShape.getDimSize(i) + padFront + padBack);
2096 }
2097
2098 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2099 return success();
2100}
2101
2102LogicalResult tosa::PadOp::verify() {
2103 if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
2104 /* outType = */ getOutput().getType())
2105 .failed()) {
2106 return failure();
2107 }
2108
2109 if (auto padConst = getPadConst()) {
2110 if (verifySameElementTypes(*this, /* inType = */ padConst.getType(),
2111 /* outType = */ getOutput().getType())
2112 .failed()) {
2113 return failure();
2114 }
2115 }
2116
2117 RankedTensorType inputType =
2118 llvm::dyn_cast<RankedTensorType>(getInput1().getType());
2119 RankedTensorType outputType =
2120 llvm::dyn_cast<RankedTensorType>(getOutput().getType());
2121 if (!inputType || !outputType)
2122 return success();
2123
2124 auto inputRank = inputType.getRank();
2125 auto outputRank = outputType.getRank();
2126 if (inputRank != outputRank)
2127 return emitOpError() << "expect same input and output tensor rank, but got "
2128 << "inputRank: " << inputRank
2129 << ", outputRank: " << outputRank;
2130
2131 DenseIntElementsAttr paddingAttr;
2132 if (!matchPattern(getPadding(), m_Constant(&paddingAttr))) {
2133 return failure();
2134 }
2135
2136 auto paddingValues = paddingAttr.getValues<APInt>();
2137 if (paddingValues.size() != static_cast<size_t>(inputRank * 2))
2138 return emitOpError() << "padding tensor must have " << inputRank
2139 << " * 2 = " << inputRank * 2 << " elements, but got "
2140 << paddingValues.size();
2141
2142 auto inputShape = inputType.getShape();
2143 auto outputShape = outputType.getShape();
2144
2145 for (int64_t i = 0; i < inputRank; ++i) {
2146 int64_t padStart = paddingValues[i * 2].getSExtValue();
2147 int64_t padEnd = paddingValues[i * 2 + 1].getSExtValue();
2148
2149 if ((padStart < 0 && padStart != -1) || (padEnd < 0 && padEnd != -1)) {
2150 return emitOpError()
2151 << "invalid padding values at dimension " << i
2152 << ": values must be non-negative or -1 for dynamic padding, got ["
2153 << padStart << ", " << padEnd << "]";
2154 }
2155
2156 // Skip shape verification for dynamic input/output
2157 if (inputShape[i] == ShapedType::kDynamic ||
2158 outputShape[i] == ShapedType::kDynamic)
2159 continue;
2160
2161 if (outputShape[i] != inputShape[i] + padStart + padEnd) {
2162 return emitOpError() << "mismatch in output shape at dimension " << i
2163 << ": expected " << inputShape[i] << " + "
2164 << padStart << " + " << padEnd << " = "
2165 << (inputShape[i] + padStart + padEnd)
2166 << ", but got " << outputShape[i];
2167 }
2168 }
2169
2170 return success();
2171}
2172
2173LogicalResult tosa::SliceOp::inferReturnTypeComponents(
2174 MLIRContext *context, ::std::optional<Location> location,
2175 SliceOp::Adaptor adaptor,
2176 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2177
2178 Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
2179 SmallVector<int64_t> start;
2180 SmallVector<int64_t> size;
2181
2182 if (!tosa::getConstShapeValues(adaptor.getStart().getDefiningOp(), start) ||
2183 !tosa::getConstShapeValues(adaptor.getSize().getDefiningOp(), size)) {
2184 auto rank = cast<tosa::shapeType>(adaptor.getSize().getType()).getRank();
2185 SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
2186 inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
2187 return success();
2188 }
2189
2190 // if size[i] is -1, all remaining elements in dimension i are included
2191 // in the slice, similar to TF.
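// e.g. slicing an input of shape [10, 8] with start = [2, 0] and
// size = [3, -1] infers an output shape of [3, 8].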
2192 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2193 // initialize outputShape to all unknown
2194 SmallVector<int64_t> outputShape(size.size(), ShapedType::kDynamic);
2195 if (inputShape.hasRank()) {
2196 for (size_t i = 0; i < size.size(); i++) {
2197 if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
2198 (ShapedType::isDynamic(inputShape.getDimSize(i)) ||
2199 start[i] < inputShape.getDimSize(i))) {
2200 // size[i] is not 0 and not < -1, and start[i] is in valid range
2201 if (ShapedType::isDynamic(inputShape.getDimSize(i))) {
2202 // input shape has unknown dim[i] - only valid if size[i] > 0
2203 if (size[i] > 0) {
2204 outputShape[i] = size[i];
2205 }
2206 } else {
2207 // input shape has known dim[i]
2208 if (size[i] == -1) {
2209 outputShape[i] = inputShape.getDimSize(i) - start[i];
2210 } else if (start[i] + size[i] <= inputShape.getDimSize(i)) {
2211 // start[i] + size[i] is within bound of input shape's dim[i]
2212 outputShape[i] = size[i];
2213 }
2214 }
2215 }
2216 }
2217 } else {
2218 outputShape = convertToMlirShape(size);
2219 }
2220 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2221 return success();
2222}
2223
2224LogicalResult tosa::SliceOp::verify() {
2225 if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
2226 /* outType = */ getOutput().getType())
2227 .failed())
2228 return failure();
2229
2230 const ShapeAdaptor inputShape(getInput1().getType());
2231 if (inputShape.hasRank()) {
2232 const auto inputRank = inputShape.getRank();
2233 const ShapeAdaptor outputShape(getOutput().getType());
2234 if (outputShape.hasRank() && inputRank != outputShape.getRank())
2235 return emitOpError(
2236 "expect input1 and output to have the same ranks, got ")
2237 << inputRank << " and " << outputShape.getRank();
2238
2239 const auto startShapeRank =
2240 llvm::cast<tosa::shapeType>(getStart().getType()).getRank();
2241 if (inputRank != startShapeRank)
2242 return emitOpError("length of start is not equal to rank of input shape");
2243
2244 const auto sizeShapeRank =
2245 llvm::cast<tosa::shapeType>(getSize().getType()).getRank();
2246 if (inputRank != sizeShapeRank)
2247 return emitOpError("length of size is not equal to rank of input shape");
2248 }
2249
2250 return success();
2251}
2252
2253LogicalResult tosa::MulOp::inferReturnTypeComponents(
2254 MLIRContext *context, ::std::optional<Location> location,
2255 ValueShapeRange operands, DictionaryAttr attributes,
2256 OpaqueProperties properties, RegionRange regions,
2257 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2258 // The mul op's output shape depends only on input1 and input2, not on shift.
2259 ValueShapeRange twoInputs = operands.drop_back();
2260 llvm::SmallVector<int64_t> outShape;
2261 if (resolveBroadcastShape(twoInputs, outShape).failed()) {
2262 inferredReturnShapes.push_back(ShapedTypeComponents());
2263 } else {
2264 inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
2265 }
2266 return success();
2267}
2268
2269LogicalResult tosa::MulOp::verify() {
2270 const Value output = getOutput();
2271 auto resElemType = getElementTypeOrSelf(output);
2272
2273 // Verify if the element type among operands and result match tosa
2274 // specification.
2275 if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
2276 IntegerType lhsIntType =
2277 dyn_cast<IntegerType>(getElementTypeOrSelf(getInput1()));
2278 IntegerType rhsIntType =
2279 dyn_cast<IntegerType>(getElementTypeOrSelf(getInput2()));
2280 if (!lhsIntType || !rhsIntType || lhsIntType != rhsIntType)
2281 return emitOpError("requires the same element type for all operands");
2282
2283 // Though the spec requires the element type of result to be i32, a more
2284 // relaxed way is provided at dialect level for easier cooperating with
2285 // other dialects.
2286 if (lhsIntType.getWidth() > resIntType.getWidth())
2287 return emitOpError("invalid data type size for operands or result");
2288
2289 } else {
2290 // For other supported types, the spec requires the same element
2291 // type for all operands (excluding the `shift` operand) and results.
2292 for (int i = 0; i < 2; ++i) {
2293 if (getElementTypeOrSelf(getOperand(i)) != resElemType)
2294 return emitOpError(
2295 "requires the same element type for all operands and results");
2296 }
2297
2298 // verify shift has value 0 for non-integer types
2299 ElementsAttr shift_elem;
2300 if (matchPattern(getShift(), m_Constant(&shift_elem))) {
2301 int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
2302 if (shift != 0) {
2303 return emitOpError() << "require shift to be 0 for float type";
2304 }
2305 }
2306 }
2307
2308 // Verify the op has the same rank for all main operands (excluding extra
2309 // operands such as the shift of mul, which is the only difference from the
2310 // built-in `SameOperandsAndResultRank` trait) and result types, if known.
2311 TypeRange operandTypes = getOperandTypes();
2312 ShapedType aType = cast<ShapedType>(operandTypes[0]);
2313 ShapedType bType = cast<ShapedType>(operandTypes[1]);
2314
2315 const bool aHasRank = aType.hasRank();
2316 const bool bHasRank = bType.hasRank();
2317 if (aHasRank && bHasRank) {
2318 const int64_t aRank = aType.getRank();
2319 const int64_t bRank = bType.getRank();
2320 if (aRank != bRank)
2321 return emitOpError("a and b operands don't have matching ranks, got ")
2322 << aRank << " and " << bRank;
2323
2324 // check for broadcast compatible shapes
2325 SmallVector<int64_t> resultShape;
2326 if (!mlir::OpTrait::util::getBroadcastedShape(
2327 aType.getShape(), bType.getShape(), resultShape))
2328 return emitOpError("a and b operands don't have broadcast-compatible "
2329 "shapes, got ")
2330 << aType << " and " << bType;
2331 }
2332
2333 ShapedType resultType = cast<ShapedType>(output.getType());
2334 if (!resultType.hasRank())
2335 return success();
2336
2337 const int64_t resultRank = resultType.getRank();
2338 if (aHasRank && resultRank != aType.getRank())
2339 return emitOpError("result type has different rank than a, got ")
2340 << resultRank << " vs " << aType.getRank();
2341 if (bHasRank && resultRank != bType.getRank())
2342 return emitOpError("result type has different rank than b, got ")
2343 << resultRank << " vs " << bType.getRank();
2344
2345 return success();
2346}
2347
2348LogicalResult tosa::TableOp::inferReturnTypeComponents(
2349 MLIRContext *context, ::std::optional<Location> location,
2350 TableOp::Adaptor adaptor,
2351 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2352 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2353
2354 if (!inputShape.hasRank()) {
2355 inferredReturnShapes.push_back(ShapedTypeComponents());
2356 return success();
2357 }
2358
2359 inferredReturnShapes.resize(1);
2360 inputShape.getDims(inferredReturnShapes[0]);
2361 return success();
2362}
2363
2364LogicalResult tosa::TableOp::verify() {
2365 const TensorType inputType = getInput1().getType();
2366 const TensorType outputType = getOutput().getType();
2367
2368 if (!inputType.hasRank() || !outputType.hasRank())
2369 return success();
2370
2371 if (inputType.getRank() != outputType.getRank())
2372 return emitOpError()
2373 << "expected input tensor rank to equal result tensor rank";
2374
2375 auto inputDims = inputType.getShape();
2376 auto outputDims = outputType.getShape();
2377 for (auto it : llvm::enumerate(llvm::zip(inputDims, outputDims))) {
2378 int64_t dim = it.index();
2379 auto [inputDim, outputDim] = it.value();
2380 if (ShapedType::isStatic(outputDim) && outputDim != inputDim) {
2381 return emitOpError() << "dim(result, " << dim << ") = " << outputDim
2382 << " doesn't match dim(input, " << dim
2383 << ") = " << inputDim;
2384 }
2385 }
2386 return success();
2387}
2388
2389LogicalResult
2390tosa::TileOp::getConstantMultiples(SmallVector<int64_t> &multiples) {
2391 // Multiples must be constants.
2392 DenseIntElementsAttr multiplesAttr;
2393 if (!matchPattern(getMultiples(), m_Constant(&multiplesAttr)))
2394 return failure();
2395 multiples = llvm::to_vector(
2396 llvm::map_range(multiplesAttr.getValues<APInt>(),
2397 [](const APInt &val) { return val.getSExtValue(); }));
2398 return success();
2399}
2400
2401LogicalResult tosa::TileOp::inferReturnTypeComponents(
2402 MLIRContext *context, ::std::optional<Location> location,
2403 TileOp::Adaptor adaptor,
2404 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2405 Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
2406 SmallVector<int64_t> multiples;
2407 if (!tosa::getConstShapeValues(adaptor.getMultiples().getDefiningOp(),
2408 multiples)) {
2409 auto rank =
2410 cast<tosa::shapeType>(adaptor.getMultiples().getType()).getRank();
2411 SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
2412 inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
2413 return success();
2414 } else {
2415 multiples = convertToMlirShape(multiples);
2416 }
2417
2418 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2419 SmallVector<int64_t> outputShape;
2420 if (!inputShape.hasRank()) {
2421 outputShape.resize(multiples.size(), ShapedType::kDynamic);
2422 inferredReturnShapes.push_back(
2423 ShapedTypeComponents(outputShape, inputType));
2424 return success();
2425 } else if (static_cast<size_t>(inputShape.getRank()) != multiples.size())
2426 return failure();
2427
2428 // Any non-dynamic dimension can be multiplied to a known size.
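// e.g. an input of shape [2, 3] with multiples = [3, 1] infers an output
// shape of [6, 3]; a dynamic input dimension or multiple stays dynamic.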
2429 outputShape.reserve(multiples.size());
2430 for (int i = 0, s = inputShape.getRank(); i < s; i++) {
2431 if (multiples[i] == ShapedType::kDynamic) {
2432 outputShape.push_back(ShapedType::kDynamic);
2433 } else {
2434 int64_t dim = inputShape.getDimSize(i);
2435 if (dim != ShapedType::kDynamic)
2436 dim *= multiples[i];
2437 outputShape.push_back(dim);
2438 }
2439 }
2440
2441 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
2442 return success();
2443}
2444
2445LogicalResult tosa::TileOp::verify() {
2446 if (verifySameElementTypes(*this, /* intype = */ getInput1().getType(),
2447 /* outType = */ getOutput().getType())
2448 .failed()) {
2449 return failure();
2450 }
2451 ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
2452 ShapedType outputType = llvm::cast<ShapedType>(getType());
2453
2454 shapeType multiplesType =
2455 llvm::cast<tosa::shapeType>(getMultiples().getType());
2456
2457 auto multiplesRank = multiplesType.getRank();
2458
2459 if (inputType.hasRank()) {
2460 if (inputType.getRank() != multiplesRank)
2461 return emitOpError("expect 'multiples' to have rank ")
2462 << inputType.getRank() << " but got " << multiplesRank << ".";
2463 if (outputType.hasRank() && inputType.getRank() != outputType.getRank())
2464 return emitOpError("expect same input and output tensor rank.");
2465 } else if (outputType.hasRank() && outputType.getRank() != multiplesRank)
2466 return emitOpError("expect 'multiples' array to have length ")
2467 << outputType.getRank() << " but got " << multiplesRank << ".";
2468
2469 SmallVector<int64_t> multiples;
2470 if (getConstantMultiples(multiples).succeeded() &&
2471 llvm::any_of(multiples, [](int64_t v) { return v <= 0 && v != -1; }))
2472 return emitOpError(
2473 "expect element of 'multiples' to be positive integer or -1.");
2474
2475 return success();
2476}
2477
2478bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
2479 if (l.size() != r.size() || l.size() != 1)
2480 return false;
2481 return getElementTypeOrSelf(l[0]) == getElementTypeOrSelf(r[0]);
2482}
2483
2484LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
2485 MLIRContext *context, ::std::optional<Location> location,
2486 ReshapeOp::Adaptor adaptor,
2487 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2488 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2489 Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
2490 llvm::SmallVector<int64_t> newShapeValue;
2491 if (!tosa::getConstShapeValues(adaptor.getShape().getDefiningOp(),
2492 newShapeValue)) {
2493 auto rank = cast<tosa::shapeType>(adaptor.getShape().getType()).getRank();
2494 SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
2495 inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
2496 return success();
2497 } else {
2498 newShapeValue = convertToMlirShape(newShapeValue);
2499 }
2500
2501 // We cannot infer from the total number of elements so we must take the
2502 // shape attribute as exact.
2503 if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
2504 inferredReturnShapes.push_back(
2505 ShapedTypeComponents(newShapeValue, inputType));
2506 return success();
2507 }
2508
2509 // Determine the number of elements covered by the static dimensions.
2510 // This allows us to infer the length of the remaining dynamic
2511 // dimension.
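// e.g. reshaping a static input holding 24 elements to [2, -1, 4] resolves
// the dynamic dimension to 24 / (2 * 4) = 3.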
2512 int64_t numElements = inputShape.getNumElements();
2513 int64_t staticMul = 1;
2514 for (auto val : newShapeValue) {
2515 if (ShapedType::isStatic(val)) {
2516 staticMul *= val;
2517 }
2518 }
2519
2520 // Determine the length of the dynamic dimension.
2521 for (auto &val : newShapeValue) {
2522 if (ShapedType::isDynamic(val))
2523 val = numElements / staticMul;
2524 }
2525
2526 inferredReturnShapes.push_back(
2527 ShapedTypeComponents(newShapeValue, inputType));
2528 return success();
2529}
2530
2531llvm::LogicalResult tosa::ReshapeOp::verify() {
2532 if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
2533 /* outType = */ getOutput().getType())
2534 .failed()) {
2535 return failure();
2536 }
2537 TensorType inputType = getInput1().getType();
2538
2539 SmallVector<int64_t> shapeValues;
2540 if (!tosa::getConstShapeValues(getShape().getDefiningOp(), shapeValues)) {
2541 // skip following checks if shape is not constant
2542 return mlir::success();
2543 }
2544
2545 int missingDims = llvm::count(shapeValues, -1);
2546 if (missingDims > 1)
2547 return emitOpError() << "expected at most one target dimension to be -1";
2548
2549 const auto outputType = dyn_cast<RankedTensorType>(getType());
2550 if (!outputType)
2551 return success();
2552
2553 if ((int64_t)shapeValues.size() != outputType.getRank())
2554 return emitOpError() << "new shape does not match result rank";
2555
2556 for (auto [newShapeDim, outputShapeDim] :
2557 zip(shapeValues, outputType.getShape())) {
2558 if (newShapeDim != -1 && newShapeDim != ShapedType::kDynamic &&
2559 outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim)
2560 return emitOpError() << "new shape is inconsistent with result shape";
2561
2562 if (newShapeDim != ShapedType::kDynamic && newShapeDim < -1)
2563 return emitOpError() << "new shape has invalid tensor dimension size "
2564 << newShapeDim;
2565 }
2566
2567 if (inputType.hasStaticShape()) {
2568 int64_t inputElementsNum = inputType.getNumElements();
2569 if (outputType.hasStaticShape()) {
2570 int64_t outputElementsNum = outputType.getNumElements();
2571 if (inputElementsNum != outputElementsNum) {
2572 return emitOpError() << "cannot reshape " << inputElementsNum
2573 << " elements into " << outputElementsNum;
2574 }
2575 }
2576
2577 int64_t newShapeElementsNum =
2578 llvm::accumulate(shapeValues, int64_t(1), [](int64_t acc, int64_t dim) {
2579 return (dim > 0) ? acc * dim : acc;
2580 });
2581 bool isStaticNewShape =
2582 llvm::all_of(shapeValues, [](int64_t s) { return s > 0; });
2583 if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
2584 (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
2585 return emitOpError() << "cannot reshape " << inputElementsNum
2586 << " elements into " << newShapeElementsNum;
2587 }
2588 }
2589
2590 return mlir::success();
2591}
2592
2593 // Return failure if val is not a constant.
2594 // Set zp to -1 if val is a non-zero float or is neither integer nor float;
2595 // otherwise set zp to val's constant value.
2596static FailureOr<int64_t> getZeroPoint(Value val, bool signExtend) {
2597 ElementsAttr zpAttr;
2598 if (!matchPattern(val, m_Constant(&zpAttr))) {
2599 return failure();
2600 }
2601
2602 Type zpElemType = zpAttr.getElementType();
2603
2604 if (llvm::isa<FloatType>(zpElemType)) {
2605 if (zpAttr.getValues<APFloat>()[0].isZero()) {
2606 return 0;
2607 }
2608 // return non-zero value to trigger error check
2609 return -1;
2610 }
2611
2612 if (llvm::isa<IntegerType>(zpElemType)) {
2613 if (signExtend)
2614 return zpAttr.getValues<APInt>()[0].getSExtValue();
2615 else
2616 return zpAttr.getValues<APInt>()[0].getZExtValue();
2617 }
2618
2619 // return non-zero value to trigger error check
2620 return -1;
2621}
2622
2623template <typename T>
2624static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp,
2625 const std::string &operand) {
2626 Type zpElemType = getElementTypeOrSelf(val);
2627
2628 if (!zpElemType.isInteger(8) && zp != 0) {
2629 // convert operand to lower case for error message
2630 std::string lower = operand;
2631 std::transform(lower.begin(), lower.end(), lower.begin(), ::tolower);
2632 return op.emitOpError()
2633 << lower << " zero point must be zero for non-int8 integer types";
2634 }
2635
2636 return success();
2637}
2638
2639static LogicalResult verifyZeroPoint(tosa::RescaleOp op, Value zpVal,
2640 const int64_t &zp,
2641 const std::string &operand) {
2642 bool isInputZp = (operand == "Input");
2643
2644 bool tensorUnsigned =
2645 isInputZp ? op.getInputUnsigned() : op.getOutputUnsigned();
2646 StringRef tensorName = isInputZp ? "input" : "output";
2647
2648 Type zpElemType = getElementTypeOrSelf(zpVal);
2649
2650 if (zp != 0) {
2651 if (!zpElemType.isInteger(8) &&
2652 !(zpElemType.isInteger(16) && tensorUnsigned)) {
2653 return op.emitOpError()
2654 << "expect " << tensorName << "_zp of 0, got " << zp;
2655 }
2656 if (zpElemType.isInteger(16) && tensorUnsigned && zp != 32768) {
2657 return op.emitOpError() << "expect " << tensorName
2658 << "_zp of 0 or 32768 for unsigned int16 "
2659 << tensorName << ", got " << zp;
2660 }
2661 }
2662
2663 return success();
2664}
2665
2666#define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND) \
2667 FailureOr<int64_t> tosa::OP::get##OPERAND_NAME##ZeroPoint() { \
2668 return getZeroPoint(get##OPERAND_NAME##Zp(), SIGN_EXTEND); \
2669 } \
2670 LogicalResult tosa::OP::verify##OPERAND_NAME##ZeroPoint(int64_t zp) { \
2671 return verifyZeroPoint(*this, get##OPERAND_NAME##Zp(), zp, #OPERAND_NAME); \
2672 }
2673
2674ZERO_POINT_HELPER(Conv2DOp, Input, true)
2675ZERO_POINT_HELPER(Conv2DOp, Weight, true)
2676ZERO_POINT_HELPER(Conv3DOp, Input, true)
2677ZERO_POINT_HELPER(Conv3DOp, Weight, true)
2678ZERO_POINT_HELPER(DepthwiseConv2DOp, Input, true)
2679ZERO_POINT_HELPER(DepthwiseConv2DOp, Weight, true)
2680ZERO_POINT_HELPER(TransposeConv2DOp, Input, true)
2681ZERO_POINT_HELPER(TransposeConv2DOp, Weight, true)
2682ZERO_POINT_HELPER(AvgPool2dOp, Input, true)
2683ZERO_POINT_HELPER(AvgPool2dOp, Output, true)
2684ZERO_POINT_HELPER(MatMulOp, A, true)
2685ZERO_POINT_HELPER(MatMulOp, B, true)
2686ZERO_POINT_HELPER(NegateOp, Input1, true)
2687ZERO_POINT_HELPER(NegateOp, Output, true)
2688ZERO_POINT_HELPER(RescaleOp, Input, !getInputUnsigned())
2689ZERO_POINT_HELPER(RescaleOp, Output, !getOutputUnsigned())
2690#undef ZERO_POINT_HELPER
2691
2692LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
2693 MLIRContext *context, ::std::optional<Location> location,
2694 TransposeOp::Adaptor adaptor,
2695 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2696 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2697
2698 // If the input rank is unknown, the output rank is
2699 // unknown.
2700 if (!inputShape.hasRank()) {
2701 inferredReturnShapes.push_back(ShapedTypeComponents());
2702 return success();
2703 }
2704
2705 const auto inputRank = inputShape.getRank();
2706
2707 // This would imply the number of permutations does not match the rank of
2708 // the input which is illegal.
2709 if (adaptor.getPerms().size() != static_cast<size_t>(inputRank)) {
2710 return failure();
2711 }
2712
2713 SmallVector<int64_t> outputShape;
2714 // Rank-0 means no permutations matter.
2715 if (inputRank == 0) {
2716 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2717 return success();
2718 }
2719
2720 // Check whether the input dimensions are all the same.
2721 bool allTheSame = true;
2722 for (int i = 1, s = inputRank; i < s; i++) {
2723 if (inputShape.getDimSize(0) != inputShape.getDimSize(i)) {
2724 allTheSame = false;
2725 break;
2726 }
2727 }
2728
2729 // If all of the input dimensions are the same we don't care about the
2730 // permutation.
2731 if (allTheSame) {
2732 outputShape.resize(inputRank, inputShape.getDimSize(0));
2733 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2734 return success();
2735 }
2736
2737 outputShape.resize(inputRank, ShapedType::kDynamic);
2738
2739 // Constant permutation values must be within the input rank.
2740 if (llvm::any_of(adaptor.getPerms(),
2741 [inputRank](const auto i) { return i >= inputRank; }))
2742 return failure();
2743
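// e.g. an input of shape [1, 2, 3] with perms = [2, 0, 1] infers an output
// shape of [3, 1, 2].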
2744 outputShape.reserve(inputRank);
2745 for (int i = 0, s = inputRank; i < s; i++) {
2746 outputShape[i] = inputShape.getDimSize(adaptor.getPerms()[i]);
2747 }
2748
2749 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2750 return success();
2751}
2752
2753LogicalResult tosa::TransposeOp::verify() {
2754 if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
2755 /* outType = */ getOutput().getType())
2756 .failed()) {
2757 return failure();
2758 }
2759
2760 const ShapeAdaptor inputShape(getInput1().getType());
2761 const ShapeAdaptor outputShape(getOutput().getType());
2762
2763 const llvm::ArrayRef<int32_t> constantPerms = getPerms();
2764
2765 if (inputShape.hasRank() &&
2766 constantPerms.size() != static_cast<size_t>(inputShape.getRank()))
2767 return emitOpError() << "expected perms attribute to have size "
2768 << inputShape.getRank()
2769 << " (input rank) but got size "
2770 << constantPerms.size();
2771
2772 if (inputShape.hasRank() && outputShape.hasRank() &&
2773 inputShape.getRank() != outputShape.getRank())
2774 return emitOpError()
2775 << "expected input tensor rank to equal result tensor rank";
2776
2777 if (outputShape.hasRank() &&
2778 constantPerms.size() != static_cast<size_t>(outputShape.getRank()))
2779 return emitOpError() << "expected perms attribute to have size "
2780 << outputShape.getRank()
2781 << " (output rank) but got size "
2782 << constantPerms.size();
2783
2784 if (!llvm::all_of(constantPerms,
2785 [&constantPerms](int32_t s) {
2786 return s >= 0 &&
2787 static_cast<size_t>(s) < constantPerms.size();
2788 }) ||
2789 !isPermutationVector(llvm::to_vector(llvm::map_range(
2790 constantPerms, [](int32_t v) -> int64_t { return v; }))))
2791 return emitOpError() << "expected valid permutation indices";
2792
2793 // ERROR_IF(tensor_size(shape1) != tensor_size(shape))
2794 if (inputShape.hasStaticShape() && outputShape.hasStaticShape() &&
2795 inputShape.getNumElements() != outputShape.getNumElements())
2796 return emitOpError() << "expected input1 and output to have same numbers "
2797 "of elements, got "
2798 << inputShape.getNumElements() << " and "
2799 << outputShape.getNumElements();
2800
2801 // Verify that the shapes of the input and output tensors are properly
2802 // permuted.
2803 if (inputShape.hasRank() && outputShape.hasRank()) {
2804 for (auto i = 0; i < outputShape.getRank(); i++) {
2805 if (inputShape.isDynamicDim(constantPerms[i]) ||
2806 outputShape.isDynamicDim(i))
2807 continue;
2808
2809 if (inputShape.getDimSize(constantPerms[i]) != outputShape.getDimSize(i))
2810 return emitOpError()
2811 << "expected output tensor dim " << i << " to match "
2812 << "input dim " << constantPerms[i] << " with value of "
2813 << inputShape.getDimSize(constantPerms[i]);
2814 }
2815 }
2816
2817 return success();
2818}
2819
2820LogicalResult TransposeOp::reifyResultShapes(
2821 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2822
2823 const llvm::ArrayRef<int32_t> transposePerms = getPerms();
2824
2825 Value input = getInput1();
2826 auto inputType = cast<TensorType>(input.getType());
2827
2828 SmallVector<OpFoldResult> returnedDims(inputType.getRank());
2829 for (auto dim : transposePerms) {
2830 int32_t dimInInput = transposePerms[dim];
2831 if (inputType.isDynamicDim(dimInInput))
2832 returnedDims[dim] =
2833 tensor::DimOp::create(builder, getLoc(), input, dimInInput)
2834 .getResult();
2835 else
2836 returnedDims[dim] =
2837 builder.getIndexAttr(inputType.getDimSize(dimInInput));
2838 }
2839
2840 reifiedReturnShapes.emplace_back(std::move(returnedDims));
2841 return success();
2842}
2843
2844LogicalResult tosa::GatherOp::inferReturnTypeComponents(
2845 MLIRContext *context, ::std::optional<Location> location,
2846 GatherOp::Adaptor adaptor,
2847 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
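// tosa.gather indexes values of shape [N, K, C] with indices of shape
// [N, W] to produce an output of shape [N, W, C].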
2848 llvm::SmallVector<int64_t> outputShape;
2849 outputShape.resize(3, ShapedType::kDynamic);
2850
2851 ShapeAdaptor valuesShape(adaptor.getValues().getType());
2852 if (valuesShape.hasRank()) {
2853 outputShape[0] = valuesShape.getDimSize(0);
2854 outputShape[2] = valuesShape.getDimSize(2);
2855 }
2856
2857 ShapeAdaptor indicesShape(adaptor.getIndices().getType());
2858 if (indicesShape.hasRank()) {
2859 if (outputShape[0] == ShapedType::kDynamic)
2860 outputShape[0] = indicesShape.getDimSize(0);
2861 if (outputShape[1] == ShapedType::kDynamic)
2862 outputShape[1] = indicesShape.getDimSize(1);
2863 }
2864
2865 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2866 return success();
2867}
2868
2869LogicalResult tosa::GatherOp::verify() {
2870 if (verifySameElementTypes(*this, /* inType = */ getValues().getType(),
2871 /* outType = */ getOutput().getType())
2872 .failed()) {
2873 return failure();
2874 }
2875
2876 const ShapeAdaptor valuesShape(getValues().getType());
2877 const ShapeAdaptor indicesShape(getIndices().getType());
2878 const ShapeAdaptor outputShape(getOutput().getType());
2879
2880 int64_t N = ShapedType::kDynamic;
2881 int64_t W = ShapedType::kDynamic;
2882 int64_t C = ShapedType::kDynamic;
2883
2884 if (valuesShape.hasRank()) {
2885 N = valuesShape.getDimSize(0);
2886 C = valuesShape.getDimSize(2);
2887 }
2888 if (indicesShape.hasRank()) {
2889 const int64_t indicesN = indicesShape.getDimSize(0);
2890 W = indicesShape.getDimSize(1);
2891 if (N == ShapedType::kDynamic)
2892 N = indicesN;
2893 else if (indicesN != ShapedType::kDynamic && N != indicesN)
2894 return emitOpError() << "requires indices dimension 0 to have size " << N
2895 << ", got " << indicesN;
2896 }
2897 if (outputShape.hasRank()) {
2898 const int64_t outputN = outputShape.getDimSize(0);
2899 const int64_t outputW = outputShape.getDimSize(1);
2900 const int64_t outputC = outputShape.getDimSize(2);
2901 if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
2902 N != outputN)
2903 return emitOpError() << "requires output dimension 0 to have size " << N
2904 << ", got " << outputN;
2905
2906 if (W != ShapedType::kDynamic && outputW != ShapedType::kDynamic &&
2907 W != outputW)
2908 return emitOpError() << "requires output dimension 1 to have size " << W
2909 << ", got " << outputW;
2910 if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
2911 C != outputC)
2912 return emitOpError() << "requires output dimension 2 to have size " << C
2913 << ", got " << outputC;
2914 }
2915 return success();
2916}
2917
2918LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
2919 MLIRContext *context, ::std::optional<Location> location,
2920 ResizeOp::Adaptor adaptor,
2921 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2922 llvm::SmallVector<int64_t, 4> outputShape;
2923 outputShape.resize(4, ShapedType::kDynamic);
2924
2925 ShapeAdaptor inputShape(adaptor.getInput().getType());
2926 if (!inputShape.hasRank())
2927 return failure();
2928
2929 outputShape[0] = inputShape.getDimSize(0);
2930 outputShape[3] = inputShape.getDimSize(3);
2931 int64_t inputHeight = inputShape.getDimSize(1);
2932 int64_t inputWidth = inputShape.getDimSize(2);
2933
2934 if ((inputHeight == ShapedType::kDynamic) ||
2935 (inputWidth == ShapedType::kDynamic))
2936 return failure();
2937
2938 SmallVector<int64_t> scaleInt, offsetInt, borderInt;
2939 if (!tosa::getConstShapeValues(adaptor.getScale().getDefiningOp(),
2940 scaleInt) ||
2941 !tosa::getConstShapeValues(adaptor.getOffset().getDefiningOp(),
2942 offsetInt) ||
2943 !tosa::getConstShapeValues(adaptor.getBorder().getDefiningOp(),
2944 borderInt)) {
2945 return failure();
2946 }
2947
2948 // Compute the output shape based on attributes: scale, offset, and border.
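// e.g. an input height of 4 with scale_y_n = 2, scale_y_d = 1, offset_y = 0
// and border_y = 0 gives output height ((4 - 1) * 2 - 0 + 0) / 1 + 1 = 7.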
2949 const int64_t outputHeight =
2950 (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) /
2951 scaleInt[1]) +
2952 1;
2953
2954 const int64_t outputWidth =
2955 (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) /
2956 scaleInt[3]) +
2957 1;
2958
2959 if (outputHeight < 0 || outputWidth < 0) {
2960 return emitOptionalError(
2961 location,
2962 "calculated output height and width must be non-negative, "
2963 "got height = ",
2964 outputHeight, ", width = ", outputWidth);
2965 }
2966
2967 outputShape[1] = outputHeight;
2968 outputShape[2] = outputWidth;
2969 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2970 return success();
2971}
2972
2973LogicalResult tosa::ResizeOp::verify() {
2974 const Value input = getInput();
2975 const Value output = getOutput();
2976 const RankedTensorType inputType =
2977 llvm::dyn_cast<RankedTensorType>(input.getType());
2978 const RankedTensorType outputType =
2979 llvm::dyn_cast<RankedTensorType>(output.getType());
2980
2981 SmallVector<int64_t> scaleValues;
2982 SmallVector<int64_t> offsetValues;
2983 SmallVector<int64_t> borderValues;
2984 if (!tosa::getConstShapeValues(getScale().getDefiningOp(), scaleValues) ||
2985 !tosa::getConstShapeValues(getOffset().getDefiningOp(), offsetValues) ||
2986 !tosa::getConstShapeValues(getBorder().getDefiningOp(), borderValues)) {
2987 // Skip following checks if shape is not constant
2988 return success();
2989 }
2990
2991 if (llvm::any_of(scaleValues, [](int64_t s) { return s <= 0; }))
2992 return emitOpError("expect all scale values to be > 0, got ")
2993 << scaleValues;
2994
2995 const int64_t scaleYN = scaleValues[0];
2996 const int64_t scaleYD = scaleValues[1];
2997 const int64_t scaleXN = scaleValues[2];
2998 const int64_t scaleXD = scaleValues[3];
2999
3000 const int64_t offsetY = offsetValues[0];
3001 const int64_t offsetX = offsetValues[1];
3002
3003 const int64_t borderY = borderValues[0];
3004 const int64_t borderX = borderValues[1];
3005
3006 if (!inputType)
3007 return success();
3008 if (!outputType)
3009 return success();
3010
3011 const int64_t oh = outputType.getDimSize(1);
3012 const int64_t ow = outputType.getDimSize(2);
3013 const int64_t ih = inputType.getDimSize(1);
3014 const int64_t iw = inputType.getDimSize(2);
3015
3016 // Don't check an input height that could be broadcast (ih == 1)
3017 // since Linalg, a consumer of TOSA, expects broadcasting support
3018 // in resize to be available. Taking the cautious approach for now,
3019 // we can consider removing support for broadcasting later.
3020 if (ih != ShapedType::kDynamic && ih != 1) {
3021 const std::optional<int64_t> calculatedOutHeightMinusOne =
3022 idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
3023 if (!calculatedOutHeightMinusOne.has_value())
3024 return emitOpError("expected (input_height - 1) * scale_y_n - offset_y + "
3025 "border_y ")
3026 << "to be wholly divisible by scale_y_d, got ((" << ih
3027 << " - 1) * " << scaleYN << " - " << offsetY << " + " << borderY
3028 << ") / " << scaleYD;
3029 const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
3030 if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
3031 return emitOpError("calculated output height did not match expected: ")
3032 << "calculated=" << calculatedOutHeight << ", expected=" << oh;
3033 }
3034
3035 // Don't check an input width that could be broadcast (iw == 1)
3036 // since Linalg, a consumer of TOSA, expects broadcasting support
3037 // in resize to be available. Taking the cautious approach for now,
3038 // we can consider removing support for broadcasting later.
3039 if (iw != ShapedType::kDynamic && iw != 1) {
3040 const int64_t scaledInWidth = (iw - 1) * scaleXN - offsetX + borderX;
3041 const std::optional<int64_t> calculatedOutWidthMinusOne =
3042 idivCheck(scaledInWidth, scaleXD);
3043 if (!calculatedOutWidthMinusOne.has_value())
3044 return emitOpError("expected (input_width - 1) * scale_x_n - offset_x + "
3045 "border_x ")
3046 << "to be wholly divisible by scale_x_d, got ((" << iw
3047 << " - 1) * " << scaleXN << " - " << offsetX << " + " << borderX
3048 << ") / " << scaleXD;
3049 const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
3050 if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
3051 return emitOpError("calculated output width did not match expected: ")
3052 << "calculated=" << calculatedOutWidth << ", expected=" << ow;
3053 }
3054
3055 return success();
3056}
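// Example (illustrative sketch, not part of the implementation): for a static
// input height of 8 with scale = [2, 1, 2, 1], offset = [0, 0] and
// border = [0, 0], the verifier above computes (8 - 1) * 2 - 0 + 0 = 14,
// checks that 14 is wholly divisible by scale_y_d = 1, and then expects the
// static output height to equal 14 / 1 + 1 = 15.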
3057
3058LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
3059 MLIRContext *context, ::std::optional<Location> location,
3060 ScatterOp::Adaptor adaptor,
3061 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3062 llvm::SmallVector<int64_t> outputShape;
3063 outputShape.resize(3, ShapedType::kDynamic);
3064
3065 ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());
3066 if (valuesInShape.hasRank()) {
3067 outputShape[0] = valuesInShape.getDimSize(0);
3068 outputShape[1] = valuesInShape.getDimSize(1);
3069 outputShape[2] = valuesInShape.getDimSize(2);
3070 }
3071
3072 ShapeAdaptor indicesShape(adaptor.getIndices().getType());
3073 if (indicesShape.hasRank()) {
3074 if (outputShape[0] == ShapedType::kDynamic)
3075 outputShape[0] = indicesShape.getDimSize(0);
3076 }
3077
3078 ShapeAdaptor inputShape(adaptor.getInput().getType());
3079 if (inputShape.hasRank()) {
3080 if (outputShape[0] == ShapedType::kDynamic)
3081 outputShape[0] = inputShape.getDimSize(0);
3082 if (outputShape[2] == ShapedType::kDynamic)
3083 outputShape[2] = inputShape.getDimSize(2);
3084 }
3085
3086 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3087 return success();
3088}
3089
3090LogicalResult tosa::ScatterOp::verify() {
3091 if (verifySameElementTypes(*this, /* inType = */ getValuesIn().getType(),
3092 /* outType = */ getValuesOut().getType())
3093 .failed() ||
3094 verifySameElementTypes(*this, /* inType = */ getInput().getType(),
3095 /* outType = */ getValuesOut().getType())
3096 .failed()) {
3097 return failure();
3098 }
3099
3100 const ShapeAdaptor valuesInShape(getValuesIn().getType());
3101 const ShapeAdaptor indicesShape(getIndices().getType());
3102 const ShapeAdaptor inputShape(getInput().getType());
3103 const ShapeAdaptor outputShape(getValuesOut().getType());
3104
3105 int64_t N = ShapedType::kDynamic;
3106 int64_t K = ShapedType::kDynamic;
3107 int64_t W = ShapedType::kDynamic;
3108 int64_t C = ShapedType::kDynamic;
3109 if (valuesInShape.hasRank()) {
3110 N = valuesInShape.getDimSize(0);
3111 K = valuesInShape.getDimSize(1);
3112 C = valuesInShape.getDimSize(2);
3113 }
3114 if (indicesShape.hasRank()) {
3115 const int64_t indicesN = indicesShape.getDimSize(0);
3116 W = indicesShape.getDimSize(1);
3117 if (N == ShapedType::kDynamic)
3118 N = indicesN;
3119 else if (indicesN != ShapedType::kDynamic && N != indicesN)
3120 return emitOpError() << "requires indices dimension 0 to have size " << N
3121 << ", got " << indicesN;
3122 }
3123 if (inputShape.hasRank()) {
3124 const int64_t inputN = inputShape.getDimSize(0);
3125 const int64_t inputW = inputShape.getDimSize(1);
3126 const int64_t inputC = inputShape.getDimSize(2);
3127 if (N == ShapedType::kDynamic)
3128 N = inputN;
3129 else if (inputN != ShapedType::kDynamic && N != inputN)
3130 return emitOpError() << "requires input dimension 0 to have size " << N
3131 << ", got " << inputN;
3132 if (W == ShapedType::kDynamic)
3133 W = inputW;
3134 else if (inputW != ShapedType::kDynamic && W != inputW)
3135 return emitOpError() << "requires input dimension 1 to have size " << W
3136 << ", got " << inputW;
3137
3138 if (C == ShapedType::kDynamic)
3139 C = inputC;
3140 else if (inputC != ShapedType::kDynamic && C != inputC)
3141 return emitOpError() << "requires input dimension 2 to have size " << C
3142 << ", got " << inputC;
3143 }
3144 if (outputShape.hasRank()) {
3145 const int64_t outputN = outputShape.getDimSize(0);
3146 const int64_t outputK = outputShape.getDimSize(1);
3147 const int64_t outputC = outputShape.getDimSize(2);
3148 if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
3149 N != outputN)
3150 return emitOpError() << "requires values_out dimension 0 to have size "
3151 << N << ", got " << outputN;
3152 if (K == ShapedType::kDynamic)
3153 K = outputK;
3154 else if (outputK != ShapedType::kDynamic && K != outputK)
3155 return emitOpError() << "requires values_out dimension 1 to have size "
3156 << K << ", got " << outputK;
3157 if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
3158 C != outputC)
3159 return emitOpError() << "requires values_out dimension 2 to have size "
3160 << C << ", got " << outputC;
3161 }
3162 if (K != ShapedType::kDynamic && W != ShapedType::kDynamic && !(K >= W))
3163 return emitOpError() << "requires dimensions K >= W, got K=" << K
3164 << " and W=" << W;
3165
3166 return success();
3167}
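// Example (illustrative sketch, not part of the implementation): with
// values_in of shape [N, K, C], indices of shape [N, W] and input of shape
// [N, W, C], e.g. tensor<2x10x8xf32>, tensor<2x4xi32> and tensor<2x4x8xf32>,
// the checks above require values_out to be tensor<2x10x8xf32> and K (10) to
// be >= W (4).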
3168
3169static LogicalResult ReduceInferReturnTypes(
3170 ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
3171 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3172 int64_t axisVal = axis.getValue().getSExtValue();
3173 if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) {
3174 inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
3175 return success();
3176 }
3177
3178 SmallVector<int64_t> outputShape;
3179 operandShape.getDims(outputShape);
3180 outputShape[axisVal] = 1;
3181 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
3182 return success();
3183}
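// Example (illustrative sketch, not part of the implementation): reducing a
// tensor<2x3x4xf32> operand along axis = 1 infers a result shape of
// [2, 1, 4]; the reduced dimension collapses to 1 and the element type is
// preserved.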
3184
3185#define COMPATIBLE_RETURN_TYPES(OP) \
3186 bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) { \
3187 if (l.size() != r.size() || l.size() != 1) \
3188 return false; \
3189 if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0])) \
3190 return false; \
3191 return succeeded(verifyCompatibleShape(l[0], r[0])); \
3192 }
3193
3194#define REDUCE_SHAPE_INFER(OP) \
3195 LogicalResult OP::inferReturnTypeComponents( \
3196 MLIRContext *context, ::std::optional<Location> location, \
3197 OP::Adaptor adaptor, \
3198 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
3199 Type inputType = \
3200 llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \
3201 ShapeAdaptor inputShape(adaptor.getInput().getType()); \
3202 const Properties &prop = adaptor.getProperties(); \
3203 return ReduceInferReturnTypes(inputShape, inputType, prop.axis, \
3204 inferredReturnShapes); \
3205 } \
3206 COMPATIBLE_RETURN_TYPES(OP)
3207
3208REDUCE_SHAPE_INFER(tosa::ReduceAllOp)
3209REDUCE_SHAPE_INFER(tosa::ReduceAnyOp)
3210REDUCE_SHAPE_INFER(tosa::ReduceMaxOp)
3211REDUCE_SHAPE_INFER(tosa::ReduceMinOp)
3212REDUCE_SHAPE_INFER(tosa::ReduceProductOp)
3213REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
3214#undef REDUCE_SHAPE_INFER
3215COMPATIBLE_RETURN_TYPES(tosa::ConcatOp)
3216#undef COMPATIBLE_RETURN_TYPES
3217
3218template <typename T>
3219static LogicalResult verifyReduceOp(T op) {
3220 // All TOSA reduce Ops have input, output and axis.
3221 TensorType inputType = op.getInput().getType();
3222 TensorType outputType = op.getOutput().getType();
3223 int32_t reduceAxis = op.getAxis();
3224
3225 if (reduceAxis < 0) {
3226 op.emitOpError("reduce axis must not be negative");
3227 return failure();
3228 }
3229 if (inputType.hasRank()) {
3230 int64_t inputRank = inputType.getRank();
3231 // We allow for a special case where the input/output shape has rank 0 and
3232 // axis is also 0.
3233 if (reduceAxis >= inputRank && (reduceAxis != 0 || inputRank != 0)) {
3234 op.emitOpError("expect input tensor rank (")
3235 << inputRank << ") to be larger than reduce axis (" << reduceAxis
3236 << ")";
3237 return failure();
3238 }
3239 }
3240 if (outputType.hasRank()) {
3241 int64_t outputRank = outputType.getRank();
3242 if (inputType.hasRank() && outputRank != inputType.getRank()) {
3243 op.emitOpError(
3244 "expect output tensor rank to be equal to input tensor rank");
3245 return failure();
3246 }
3247 if (reduceAxis >= outputRank && (reduceAxis != 0 || outputRank != 0)) {
3248 op.emitOpError("expect output tensor rank (")
3249 << outputRank << ") to be larger than reduce axis (" << reduceAxis
3250 << ")";
3251 return failure();
3252 }
3253 // We can only verify the reduced dimension size to be 1 if this is not
3254 // the special case of output rank == 0.
3255 if (outputRank != 0) {
3256 auto outputShape = outputType.getShape();
3257 if (!outputType.isDynamicDim(reduceAxis) &&
3258 outputShape[reduceAxis] != 1) {
3259 op.emitOpError("expect reduced dimension size to be 1, got ")
3260 << outputShape[reduceAxis];
3261 return failure();
3262 }
3263 }
3264 }
3265 return success();
3266}
3267
3268LogicalResult tosa::ReduceAllOp::verify() { return verifyReduceOp(*this); }
3269LogicalResult tosa::ReduceAnyOp::verify() { return verifyReduceOp(*this); }
3270LogicalResult tosa::ReduceMaxOp::verify() { return verifyReduceOp(*this); }
3271LogicalResult tosa::ReduceMinOp::verify() { return verifyReduceOp(*this); }
3272LogicalResult tosa::ReduceProductOp::verify() { return verifyReduceOp(*this); }
3273LogicalResult tosa::ReduceSumOp::verify() { return verifyReduceOp(*this); }
3274
3275static LogicalResult NAryInferReturnTypes(
3276 const ValueShapeRange &operands,
3277 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3278 llvm::SmallVector<int64_t> outShape;
3279 if (resolveBroadcastShape(operands, outShape).failed()) {
3280 inferredReturnShapes.push_back(ShapedTypeComponents());
3281 } else {
3282 inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
3283 }
3284 return success();
3285}
3286
3287#define NARY_SHAPE_INFER(OP) \
3288 LogicalResult OP::inferReturnTypeComponents( \
3289 MLIRContext *context, ::std::optional<Location> location, \
3290 ValueShapeRange operands, DictionaryAttr attributes, \
3291 OpaqueProperties properties, RegionRange regions, \
3292 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
3293 return NAryInferReturnTypes(operands, inferredReturnShapes); \
3294 }
3295
3296NARY_SHAPE_INFER(tosa::AbsOp)
3297NARY_SHAPE_INFER(tosa::AddOp)
3298NARY_SHAPE_INFER(tosa::ArithmeticRightShiftOp)
3299NARY_SHAPE_INFER(tosa::BitwiseAndOp)
3300NARY_SHAPE_INFER(tosa::BitwiseOrOp)
3301NARY_SHAPE_INFER(tosa::BitwiseXorOp)
3302NARY_SHAPE_INFER(tosa::BitwiseNotOp)
3303NARY_SHAPE_INFER(tosa::CastOp)
3304NARY_SHAPE_INFER(tosa::CeilOp)
3305NARY_SHAPE_INFER(tosa::ClampOp)
3306NARY_SHAPE_INFER(tosa::ClzOp)
3307NARY_SHAPE_INFER(tosa::CosOp)
3308NARY_SHAPE_INFER(tosa::ExpOp)
3309NARY_SHAPE_INFER(tosa::FloorOp)
3310NARY_SHAPE_INFER(tosa::GreaterEqualOp)
3311NARY_SHAPE_INFER(tosa::GreaterOp)
3312NARY_SHAPE_INFER(tosa::IdentityOp)
3313NARY_SHAPE_INFER(tosa::IntDivOp)
3314NARY_SHAPE_INFER(tosa::LogOp)
3315NARY_SHAPE_INFER(tosa::LogicalAndOp)
3316NARY_SHAPE_INFER(tosa::LogicalLeftShiftOp)
3317NARY_SHAPE_INFER(tosa::LogicalNotOp)
3318NARY_SHAPE_INFER(tosa::LogicalOrOp)
3319NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
3320NARY_SHAPE_INFER(tosa::LogicalXorOp)
3321NARY_SHAPE_INFER(tosa::MaximumOp)
3322NARY_SHAPE_INFER(tosa::MinimumOp)
3323NARY_SHAPE_INFER(tosa::PowOp)
3324NARY_SHAPE_INFER(tosa::ReciprocalOp)
3325NARY_SHAPE_INFER(tosa::ReverseOp)
3326NARY_SHAPE_INFER(tosa::RsqrtOp)
3327NARY_SHAPE_INFER(tosa::SinOp)
3328NARY_SHAPE_INFER(tosa::SelectOp)
3329NARY_SHAPE_INFER(tosa::SubOp)
3330NARY_SHAPE_INFER(tosa::TanhOp)
3331NARY_SHAPE_INFER(tosa::ErfOp)
3332NARY_SHAPE_INFER(tosa::SigmoidOp)
3333#undef NARY_SHAPE_INFER
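// Example (illustrative sketch, not part of the implementation): for a binary
// op such as tosa.add, NAryInferReturnTypes resolves the broadcast of
// tensor<2x1xf32> and tensor<1x3xf32> to the shape [2, 3]; if the operand
// shapes cannot be broadcast together, an unranked result is inferred instead
// of failing.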
3334
3335LogicalResult tosa::NegateOp::inferReturnTypeComponents(
3336 MLIRContext *context, ::std::optional<Location> location,
3337 NegateOp::Adaptor adaptor,
3338 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3339 ShapeAdaptor inputShape(adaptor.getInput1().getType());
3340 inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
3341 return success();
3342}
3343
3344LogicalResult tosa::NegateOp::verify() {
3345 // Verify same element type
3346 const Type input1Type = getInput1().getType();
3347 const Type outputType = getOutput().getType();
3348 if (verifySameElementTypes(*this, input1Type, outputType).failed())
3349 return failure();
3350
3351 // Verify same shape
3352 const SmallVector<Type, 2> types = {input1Type, outputType};
3353 if (failed(verifyCompatibleShapes(types)))
3354 return emitOpError() << "requires the same shape for input1 and output";
3355
3356 const Type input1EType = getStorageElementTypeOrSelf(getInput1().getType());
3357 const Type input1ZpEType =
3358 getStorageElementTypeOrSelf(getInput1Zp().getType());
3359 if (input1EType != input1ZpEType) {
3360 return emitOpError("expect both input1 and its zero point are the same "
3361 "element type, got ")
3362 << input1EType << " and " << input1ZpEType;
3363 }
3364 const Type outputEType = getStorageElementTypeOrSelf(getOutput().getType());
3365 const Type outputZpEType =
3366 getStorageElementTypeOrSelf(getOutputZp().getType());
3367 if (outputEType != outputZpEType) {
3368 return emitOpError("expect both output and its zero point are the same "
3369 "element type, got ")
3370 << outputEType << " and " << outputZpEType;
3371 }
3372
3373 FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
3374 if (succeeded(maybeIZp) && verifyInput1ZeroPoint(*maybeIZp).failed())
3375 return failure();
3376
3377 FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
3378 if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
3379 return failure();
3380
3381 return success();
3382}
3383
3384static LogicalResult poolingInferReturnTypes(
3385 ShapeAdaptor inputShape, ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride,
3386 ArrayRef<int64_t> pad,
3387 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3388 llvm::SmallVector<int64_t> outputShape;
3389 outputShape.resize(4, ShapedType::kDynamic);
3390
3391 // If the input shape is unavailable, only the output rank (4) is known; all
3392 // dims stay dynamic.
3392 if (!inputShape) {
3393 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3394 return success();
3395 }
3396
3397 // The batch and channel dimensions pass through a pooling layer unchanged.
3398 outputShape[0] = inputShape.getDimSize(0);
3399 outputShape[3] = inputShape.getDimSize(3);
3400
3401 int64_t height = inputShape.getDimSize(1);
3402 int64_t width = inputShape.getDimSize(2);
3403
3404 if (ShapedType::isStatic(height)) {
3405 int64_t padded = height + pad[0] + pad[1] - kernel[0];
3406 outputShape[1] = padded / stride[0] + 1;
3407 }
3408
3409 if (ShapedType::isStatic(width)) {
3410 int64_t padded = width + pad[2] + pad[3] - kernel[1];
3411 outputShape[2] = padded / stride[1] + 1;
3412 }
3413
3414 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3415 return success();
3416}
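// Example (illustrative sketch, not part of the implementation): for an NHWC
// input of height 32 and width 32 with kernel = [2, 2], stride = [2, 2] and
// pad = [0, 0, 0, 0], the computation above yields (32 + 0 + 0 - 2) / 2 + 1 =
// 16 for both spatial output dimensions.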
3417
3418LogicalResult Conv2DOp::inferReturnTypeComponents(
3419 MLIRContext *context, ::std::optional<Location> location,
3420 Conv2DOp::Adaptor adaptor,
3421 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3422 llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
3423
3424 int64_t inputWidth = ShapedType::kDynamic;
3425 int64_t inputHeight = ShapedType::kDynamic;
3426 int64_t weightWidth = ShapedType::kDynamic;
3427 int64_t weightHeight = ShapedType::kDynamic;
3428
3429 // Input shape describes input width/height and batch.
3430
3431 ShapeAdaptor inputShape(adaptor.getInput().getType());
3432 if (inputShape.hasRank()) {
3433 outputShape[0] = inputShape.getDimSize(0);
3434 inputHeight = inputShape.getDimSize(1);
3435 inputWidth = inputShape.getDimSize(2);
3436 }
3437
3438 // The weight shape describes the filter height/width and the output channels.
3439 ShapeAdaptor weightShape(adaptor.getWeight().getType());
3440 if (weightShape.hasRank()) {
3441 outputShape[3] = weightShape.getDimSize(0);
3442 weightHeight = weightShape.getDimSize(1);
3443 weightWidth = weightShape.getDimSize(2);
3444 }
3445
3446 // Bias shape can describe the output channels.
3447 ShapeAdaptor biasShape(adaptor.getBias().getType());
3448 if (biasShape.hasRank()) {
3449 outputShape[3] = ShapedType::isDynamic(outputShape[3])
3450 ? biasShape.getDimSize(0)
3451 : outputShape[3];
3452 }
3453
3454 llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
3455 llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3456 llvm::ArrayRef<int64_t> padding = adaptor.getPad();
3457
3458 if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
3459 int64_t inputSize = inputHeight + padding[0] + padding[1];
3460 int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
3461 int64_t unstridedResult = inputSize - filterSize + 1;
3462 outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
3463 }
3464
3465 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
3466 int64_t inputSize = inputWidth + padding[2] + padding[3];
3467 int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
3468 int64_t unstridedResult = inputSize - filterSize + 1;
3469 outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
3470 }
3471
3472 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3473 return success();
3474}
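// Example (illustrative sketch, not part of the implementation): with input
// height 16, pad top/bottom of 1, a 3x3 kernel, dilation = [1, 1] and
// stride = [1, 1], the height computation above gives inputSize = 18,
// filterSize = 3, unstridedResult = 16 and an inferred output height of
// (16 - 1) / 1 + 1 = 16.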
3475
3476LogicalResult Conv2DOp::verify() {
3477 if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
3478 verifyConvOpErrorIf(*this).failed())
3479 return failure();
3480 return success();
3481}
3482
3483LogicalResult Conv3DOp::inferReturnTypeComponents(
3484 MLIRContext *context, ::std::optional<Location> location,
3485 Conv3DOp::Adaptor adaptor,
3486 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3487 llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamic);
3488
3489 int64_t inputWidth = ShapedType::kDynamic;
3490 int64_t inputHeight = ShapedType::kDynamic;
3491 int64_t inputDepth = ShapedType::kDynamic;
3492
3493 int64_t weightWidth = ShapedType::kDynamic;
3494 int64_t weightHeight = ShapedType::kDynamic;
3495 int64_t weightDepth = ShapedType::kDynamic;
3496
3497 // The input shape describes the batch and the input depth/height/width.
3498 ShapeAdaptor inputShape(adaptor.getInput().getType());
3499 if (inputShape.hasRank()) {
3500 outputShape[0] = inputShape.getDimSize(0);
3501 inputDepth = inputShape.getDimSize(1);
3502 inputHeight = inputShape.getDimSize(2);
3503 inputWidth = inputShape.getDimSize(3);
3504 }
3505
3506 // The weight shape describes the filter depth/height/width and the output channels.
3507 ShapeAdaptor weightShape(adaptor.getWeight().getType());
3508 if (weightShape.hasRank()) {
3509 outputShape[4] = weightShape.getDimSize(0);
3510 weightDepth = weightShape.getDimSize(1);
3511 weightHeight = weightShape.getDimSize(2);
3512 weightWidth = weightShape.getDimSize(3);
3513 }
3514
3515 // Bias shape can describe the output channels.
3516 ShapeAdaptor biasShape(adaptor.getBias().getType());
3517 if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) {
3518 outputShape[4] = biasShape.getDimSize(0);
3519 }
3520
3521 llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
3522 llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3523 llvm::ArrayRef<int64_t> pad = adaptor.getPad();
3524
3525 if (ShapedType::isStatic(inputDepth) && ShapedType::isStatic(weightDepth)) {
3526 int32_t inputSize = inputDepth + pad[0] + pad[1];
3527 int32_t filterSize = (weightDepth - 1) * dilation[0] + 1;
3528 int32_t unstridedResult = inputSize - filterSize + 1;
3529 outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
3530 }
3531
3532 if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
3533 int32_t inputSize = inputHeight + pad[2] + pad[3];
3534 int32_t filterSize = (weightHeight - 1) * dilation[1] + 1;
3535 int32_t unstridedResult = inputSize - filterSize + 1;
3536 outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
3537 }
3538
3539 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
3540 int32_t inputSize = inputWidth + pad[4] + pad[5];
3541 int32_t filterSize = (weightWidth - 1) * dilation[2] + 1;
3542 int32_t unstridedResult = inputSize - filterSize + 1;
3543 outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
3544 }
3545
3546 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3547 return success();
3548}
3549
3550LogicalResult Conv3DOp::verify() {
3551 if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
3552 verifyConvOpErrorIf(*this).failed())
3553 return failure();
3554 return success();
3555}
3556
3557LogicalResult AvgPool2dOp::inferReturnTypeComponents(
3558 MLIRContext *context, ::std::optional<Location> location,
3559 AvgPool2dOp::Adaptor adaptor,
3560 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3561 ShapeAdaptor inputShape(adaptor.getInput().getType());
3562 const Properties &prop = adaptor.getProperties();
3563 return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
3564 inferredReturnShapes);
3565}
3566
3567LogicalResult MaxPool2dOp::inferReturnTypeComponents(
3568 MLIRContext *context, ::std::optional<Location> location,
3569 MaxPool2dOp::Adaptor adaptor,
3570 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3571 ShapeAdaptor inputShape(adaptor.getInput().getType());
3572 const Properties &prop = adaptor.getProperties();
3573 return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
3574 inferredReturnShapes);
3575}
3576
3577LogicalResult MaxPool2dOp::verify() {
3578 if (failed(verifySameElementTypes(*this, /* intype = */ getInput().getType(),
3579 /* outType = */ getOutput().getType())))
3580 return failure();
3581
3582 if (failed(verifyPoolingOp(*this)))
3583 return failure();
3584
3585 return success();
3586}
3587
3588LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
3589 MLIRContext *context, ::std::optional<Location> location,
3590 DepthwiseConv2DOp::Adaptor adaptor,
3591 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3592 llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
3593
3594 int64_t inputWidth = ShapedType::kDynamic;
3595 int64_t inputHeight = ShapedType::kDynamic;
3596 int64_t inputChannels = ShapedType::kDynamic;
3597
3598 int64_t weightWidth = ShapedType::kDynamic;
3599 int64_t weightHeight = ShapedType::kDynamic;
3600 int64_t depthChannels = ShapedType::kDynamic;
3601
3602 // The input shape describes the batch, input height/width and input channels.
3603 ShapeAdaptor inputShape(adaptor.getInput().getType());
3604 if (inputShape.hasRank()) {
3605 outputShape[0] = inputShape.getDimSize(0);
3606 inputHeight = inputShape.getDimSize(1);
3607 inputWidth = inputShape.getDimSize(2);
3608 inputChannels = inputShape.getDimSize(3);
3609 }
3610
3611 // The weight shape describes the filter height/width, input channels and depth multiplier.
3612 ShapeAdaptor weightShape(adaptor.getWeight().getType());
3613 if (weightShape.hasRank()) {
3614 weightHeight = weightShape.getDimSize(0);
3615 weightWidth = weightShape.getDimSize(1);
3616 inputChannels = ShapedType::isDynamic(inputChannels)
3617 ? weightShape.getDimSize(2)
3618 : inputChannels;
3619 depthChannels = weightShape.getDimSize(3);
3620 }
3621
3622 // If both inputChannels and depthChannels are available we can determine
3623 // the output channels.
3624 if (ShapedType::isStatic(inputChannels) &&
3625 ShapedType::isStatic(depthChannels)) {
3626 outputShape[3] = inputChannels * depthChannels;
3627 }
3628
3629 // Bias shape can describe the output channels.
3630 ShapeAdaptor biasShape(adaptor.getBias().getType());
3631 if (biasShape.hasRank()) {
3632 outputShape[3] = ShapedType::isDynamic(outputShape[3])
3633 ? biasShape.getDimSize(0)
3634 : outputShape[3];
3635 }
3636
3637 llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
3638 llvm::ArrayRef<int64_t> padding = adaptor.getPad();
3639 llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3640
3641 if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
3642 int64_t inputSize = inputHeight + padding[0] + padding[1];
3643 int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
3644 int64_t unstridedResult = inputSize - filterSize + 1;
3645 outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
3646 }
3647
3648 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
3649 int64_t inputSize = inputWidth + padding[2] + padding[3];
3650 int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
3651 int64_t unstridedResult = inputSize - filterSize + 1;
3652 outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
3653 }
3654
3655 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3656 return success();
3657}
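// Example (illustrative sketch, not part of the implementation): a depthwise
// weight of shape [KH, KW, C, M], e.g. tensor<3x3x8x2xf32>, combined with an
// input of 8 channels yields 8 * 2 = 16 inferred output channels; the spatial
// dimensions follow the same padded/dilated/strided formula as regular
// convolution.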
3658
3659LogicalResult DepthwiseConv2DOp::verify() {
3660 if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
3661 verifyConvOpErrorIf(*this).failed())
3662 return failure();
3663 return success();
3664}
3665
3666LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
3667 MLIRContext *context, ::std::optional<Location> location,
3668 TransposeConv2DOp::Adaptor adaptor,
3669 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3670 llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
3671
3672 int64_t inputWidth = ShapedType::kDynamic;
3673 int64_t inputHeight = ShapedType::kDynamic;
3674 int64_t weightWidth = ShapedType::kDynamic;
3675 int64_t weightHeight = ShapedType::kDynamic;
3676
3677 // Input shape describes input width/height and batch.
3678 ShapeAdaptor inputShape(adaptor.getInput().getType());
3679 if (inputShape.hasRank()) {
3680 outputShape[0] = ShapedType::isDynamic(outputShape[0])
3681 ? inputShape.getDimSize(0)
3682 : outputShape[0];
3683 inputHeight = inputShape.getDimSize(1);
3684 inputWidth = inputShape.getDimSize(2);
3685 }
3686
3687 // The weight shape describes the filter height/width and the output channels.
3688 ShapeAdaptor weightShape(adaptor.getWeight().getType());
3689 if (weightShape.hasRank()) {
3690 outputShape[3] = ShapedType::isDynamic(outputShape[3])
3691 ? weightShape.getDimSize(0)
3692 : outputShape[3];
3693 weightHeight = weightShape.getDimSize(1);
3694 weightWidth = weightShape.getDimSize(2);
3695 }
3696
3697 // Bias shape can describe the output channels.
3698 ShapeAdaptor biasShape(adaptor.getBias().getType());
3699 if (biasShape.hasRank()) {
3700 outputShape[3] = ShapedType::isDynamic(outputShape[3])
3701 ? biasShape.getDimSize(0)
3702 : outputShape[3];
3703 }
3704
3705 llvm::ArrayRef<int64_t> padding = adaptor.getOutPad();
3706 llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3707
3708 if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
3709 int64_t calculateSize =
3710 (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
3711 outputShape[1] =
3712 ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
3713 }
3714
3715 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
3716 int64_t calculateSize =
3717 (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
3718 outputShape[2] =
3719 ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];
3720 }
3721
3722 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3723 return success();
3724}
3725
3726LogicalResult TransposeConv2DOp::verify() {
3727 if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
3728 return failure();
3729
3730 const llvm::ArrayRef<int64_t> strides = getStride();
3731 const int64_t strideY = strides[0];
3732 const int64_t strideX = strides[1];
3733
3734 if (strideY < 1 || strideX < 1)
3735 return emitOpError("expect all stride values to be >= 1, got [")
3736 << strides << "]";
3737
3738 const auto checkPadAgainstKernelDim =
3739 [this](int64_t pad_value, int64_t kernel_dim_size,
3740 llvm::StringRef pad_name,
3741 llvm::StringRef kernel_dim_name) -> LogicalResult {
3742 if (pad_value <= -kernel_dim_size)
3743 return emitOpError("expected ")
3744 << pad_name << " > -" << kernel_dim_name
3745 << ", but got: " << pad_name << "=" << pad_value << " and "
3746 << kernel_dim_name << "=" << kernel_dim_size;
3747 return success();
3748 };
3749
3750 const llvm::ArrayRef<int64_t> padding = getOutPad();
3751 const int64_t outPadTop = padding[0];
3752 const int64_t outPadBottom = padding[1];
3753 const int64_t outPadLeft = padding[2];
3754 const int64_t outPadRight = padding[3];
3755
3756 const auto weightType =
3757 llvm::dyn_cast<RankedTensorType>(getWeight().getType());
3758
3759 if (weightType) {
3760 const int64_t kernelHeight = weightType.getDimSize(1);
3761 if (ShapedType::isStatic(kernelHeight)) {
3762 if (failed(checkPadAgainstKernelDim(outPadTop, kernelHeight,
3763 "out_pad_top", "KH")))
3764 return failure();
3765
3766 if (failed(checkPadAgainstKernelDim(outPadBottom, kernelHeight,
3767 "out_pad_bottom", "KH")))
3768 return failure();
3769 }
3770
3771 const int64_t kernelWidth = weightType.getDimSize(2);
3772 if (ShapedType::isStatic(kernelWidth)) {
3773 if (failed(checkPadAgainstKernelDim(outPadLeft, kernelWidth,
3774 "out_pad_left", "KW")))
3775 return failure();
3776
3777 if (failed(checkPadAgainstKernelDim(outPadRight, kernelWidth,
3778 "out_pad_right", "KW")))
3779 return failure();
3780 }
3781 }
3782
3783 // The remaining checks require the output type to be a RankedTensorType.
3784 const auto outputType =
3785 llvm::dyn_cast<RankedTensorType>(getOutput().getType());
3786 if (!outputType)
3787 return success();
3788
3789 const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
3790 if (inputType && weightType) {
3791 const int64_t inputHeight = inputType.getDimSize(1);
3792 const int64_t kernelHeight = weightType.getDimSize(1);
3793 const int64_t outputHeight = outputType.getDimSize(1);
3794
3795 if (ShapedType::isStatic(inputHeight) &&
3796 ShapedType::isStatic(outputHeight)) {
3797 if (outputHeight !=
3798 (inputHeight - 1) * strideY + outPadTop + outPadBottom + kernelHeight)
3799 return emitOpError(
3800 "dimension mismatch: expected OH == (IH - 1) * stride_y "
3801 "+ out_pad_top + out_pad_bottom + KH, but got ")
3802 << outputHeight << " != (" << inputHeight << " - 1) * "
3803 << strideY << " + " << outPadTop << " + " << outPadBottom
3804 << " + " << kernelHeight;
3805 }
3806
3807 const int64_t inputWidth = inputType.getDimSize(2);
3808 const int64_t kernelWidth = weightType.getDimSize(2);
3809 const int64_t outputWidth = outputType.getDimSize(2);
3810
3811 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(outputWidth)) {
3812 if (outputWidth !=
3813 (inputWidth - 1) * strideX + outPadLeft + outPadRight + kernelWidth)
3814 return emitOpError(
3815 "dimension mismatch: expected OW == (IW - 1) * stride_x "
3816 "+ out_pad_left + out_pad_right + KW, but got ")
3817 << outputWidth << " != (" << inputWidth << " - 1) * " << strideX
3818 << " + " << outPadLeft << " + " << outPadRight << " + "
3819 << kernelWidth;
3820 }
3821 }
3822
3823 const auto biasType = llvm::dyn_cast<RankedTensorType>(getBias().getType());
3824
3825 if (!biasType)
3826 return success();
3827
3828 const int64_t biasChannels = biasType.getDimSize(0);
3829
3830 // Skip further checks if bias is dynamic
3831 if (biasChannels == ShapedType::kDynamic)
3832 return success();
3833
3834 const int64_t outputChannels = outputType.getDimSize(3);
3835 if (!ShapedType::isDynamic(outputChannels) &&
3836 biasChannels != outputChannels && biasChannels != 1)
3837 return emitOpError(
3838 "bias channels expected to be equal to output channels (")
3839 << outputChannels << ") or 1, got " << biasChannels;
3840
3841 return success();
3842}
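// Example (illustrative sketch, not part of the implementation): with input
// height 8, stride_y = 2, out_pad_top = out_pad_bottom = 0 and a kernel height
// of 3, the check above requires the static output height to equal
// (8 - 1) * 2 + 0 + 0 + 3 = 17.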
3843
3844LogicalResult RescaleOp::verify() {
3845 auto inputType = llvm::dyn_cast<ShapedType>(getInput().getType());
3846 if (!inputType) {
3847 emitOpError("expect shaped tensor for input, got ") << getInput().getType();
3848 return failure();
3849 }
3850
3851 auto inputElementType =
3852 getStorageElementTypeOrSelf(inputType.getElementType());
3853 if (!mlir::isa<IntegerType>(inputElementType)) {
3854 emitOpError("expect input to have integer element type, got ")
3855 << inputElementType;
3856 return failure();
3857 }
3858
3859 auto outputType = llvm::dyn_cast<ShapedType>(getOutput().getType());
3860 if (!outputType) {
3861 emitOpError("expect shaped tensor for output, got ")
3862 << getOutput().getType();
3863 return failure();
3864 }
3865
3866 auto outputElementType =
3867 getStorageElementTypeOrSelf(outputType.getElementType());
3868 if (!mlir::isa<IntegerType>(outputElementType)) {
3869 emitOpError("expect output to have integer element type, got ")
3870 << outputElementType;
3871 return failure();
3872 }
3873
3874 if (verifyRescaleValueAndZpTypes(*this, getInput(), getInputZp(), "input")
3875 .failed())
3876 return failure();
3877
3878 if (verifyRescaleValueAndZpTypes(*this, getOutput(), getOutputZp(), "output")
3879 .failed())
3880 return failure();
3881
3882 FailureOr<int64_t> maybeIZp = getInputZeroPoint();
3883 if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
3884 return failure();
3885
3886 FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
3887 if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
3888 return failure();
3889
3890 auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().getType());
3891 if (!multiplierType) {
3892 emitOpError("expect shaped tensor for multiplier, got ")
3893 << getMultiplier().getType();
3894 return failure();
3895 }
3896
3897 auto shiftType = llvm::dyn_cast<ShapedType>(getShift().getType());
3898 if (!shiftType) {
3899 emitOpError("expect shaped tensor for shift, got ") << getShift().getType();
3900 return failure();
3901 }
3902
3903 // multiplier element type must be i32 for scale32 = true
3904 if (getScale32() && !multiplierType.getElementType().isInteger(32)) {
3905 emitOpError("expect i32 element type for multiplier for scale32=true, got ")
3906 << multiplierType.getElementType();
3907 return failure();
3908 }
3909
3910 // multiplier element type must be i16 for scale32 = false
3911 if (!getScale32() && !multiplierType.getElementType().isInteger(16)) {
3913 "expect i16 element type for multiplier for scale32=false, got ")
3914 << multiplierType.getElementType();
3915 return failure();
3916 }
3917
3918 if (!inputType.hasRank())
3919 return success();
3920
3921 // multiplier/shift must have shape = {numChannels},
3922 // where numChannel is 1 if per_channel = false
3923 // otherwise numChannel is dimension in input shape's last axis
3924 int64_t numChannels = 1;
3925 if (getPerChannel()) {
3926 if (inputType.getRank() < 1) {
3927 emitOpError("requires input to be at least rank 1 when per_channel is "
3928 "true, but got rank ")
3929 << inputType.getRank();
3930 return failure();
3931 }
3932 numChannels = inputType.getDimSize(inputType.getRank() - 1);
3933 }
3934
3935 if (!multiplierType.hasRank())
3936 return success();
3937
3938 ArrayRef<int64_t> multiplierShape = multiplierType.getShape();
3939 // multiplier input has rank 1 by dialect definition
3940 if (multiplierShape[0] != ShapedType::kDynamic &&
3941 multiplierShape[0] != numChannels) {
3942 emitOpError("expect shape of { ")
3943 << numChannels << " } for multiplier input, got { "
3944 << multiplierShape[0] << " }";
3945 return failure();
3946 }
3947
3948 if (!shiftType.hasRank())
3949 return success();
3950
3951 ArrayRef<int64_t> shiftShape = shiftType.getShape();
3952 // shift input has rank 1 by dialect definition
3953 if (shiftShape[0] != ShapedType::kDynamic && shiftShape[0] != numChannels) {
3954 emitOpError("expect shape of { ")
3955 << numChannels << " } for shift input, got { " << shiftShape[0] << " }";
3956 return failure();
3957 }
3958
3959 return success();
3960}
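// Example (illustrative sketch, not part of the implementation): for an input
// of type tensor<1x8x8x16xi8> with per_channel = true, multiplier and shift
// must both have shape { 16 } (the size of the last input axis); with
// per_channel = false they must have shape { 1 }. With scale32 = true the
// multiplier element type must be i32, otherwise i16.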
3961
3962LogicalResult RescaleOp::inferReturnTypeComponents(
3963 MLIRContext *context, ::std::optional<Location> location,
3964 RescaleOp::Adaptor adaptor,
3965 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3966 ShapeAdaptor inputShape(adaptor.getInput().getType());
3967 inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
3968 return success();
3969}
3970
3971LogicalResult CastFromBlockScaledOp::inferReturnTypeComponents(
3972 MLIRContext *context, ::std::optional<Location> location,
3973 CastFromBlockScaledOp::Adaptor adaptor,
3974 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3975 const ShapeAdaptor inputShape(adaptor.getInputData().getType());
3976 inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
3977 return success();
3978}
3979
3980LogicalResult CastFromBlockScaledOp::verify() {
3981 const Type inputDataType = getInputData().getType();
3982 const Type outputDataType = getResult().getType();
3983 if (failed(verifyCompatibleShape(inputDataType, outputDataType)))
3984 return emitOpError() << "require compatible shapes for input_data ("
3985 << inputDataType << ") and "
3986 << "output_data (" << outputDataType << ")";
3987
3988 const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
3989
3990 if (inputDataShape.hasRank()) {
3991 const unsigned int blockSize =
3992 BlockSizeAttr::getBlockSizeValue(getBlockSize());
3993 const int64_t inputDataLastDim =
3994 inputDataShape.getDimSize(inputDataShape.getRank() - 1);
3995 if (inputDataLastDim % blockSize != 0)
3996 return emitOpError() << "expect last dimension of input_data ("
3997 << inputDataLastDim
3998 << ") to be divisible by block_size (" << blockSize
3999 << ")";
4000
4001 const Type inputScaleType = getInputScale().getType();
4002 const ShapeAdaptor inputScaleShape = ShapeAdaptor(inputScaleType);
4003
4004 if (inputScaleShape.hasRank()) {
4005 SmallVector<int64_t> inputDataDims, inputScaleDims;
4006 inputDataShape.getDims(inputDataDims);
4007 inputScaleShape.getDims(inputScaleDims);
4008
4009 if (inputDataDims.size() != inputScaleDims.size() ||
4010 failed(verifyCompatibleShape(
4011 ArrayRef<int64_t>(inputDataDims).drop_back(1),
4012 ArrayRef<int64_t>(inputScaleDims).drop_back(1))))
4013 return emitOpError() << "require compatible shapes for input_data ("
4014 << inputDataType << ") and "
4015 << "input_scale (" << inputScaleType
4016 << ") except for the last dimension";
4017
4018 const SmallVector<int64_t, 2> dimsToCheck{inputDataLastDim / blockSize,
4019 inputScaleDims.back()};
4020 if (ShapedType::isStatic(inputDataLastDim) &&
4021 failed(verifyCompatibleDims(dimsToCheck)))
4022 return emitOpError()
4023 << "expect last dimension of input_scale ("
4024 << inputScaleDims.back()
4025 << ") to be equal to last dimension of input_data / block_size ("
4026 << inputDataDims.back() / blockSize << ")";
4027 }
4028 }
4029
4030 return success();
4031}
4032
4033LogicalResult CastToBlockScaledOp::inferReturnTypeComponents(
4034 MLIRContext *context, ::std::optional<Location> location,
4035 CastToBlockScaledOp::Adaptor adaptor,
4036 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4037 const ShapeAdaptor inputShape(adaptor.getInputData().getType());
4038 inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
4039 if (!inputShape.hasRank())
4040 return success();
4041
4042 // Calculate output_scale shape if ranked input provided
4043 SmallVector<int64_t> outputScaleShape;
4044 inputShape.getDims(outputScaleShape);
4045 const int64_t lastDimLoc = inputShape.getRank() - 1;
4046 const int64_t lastDimSize = inputShape.getDimSize(lastDimLoc);
4047 if (ShapedType::isStatic(lastDimSize)) {
4048 const unsigned int blockSize =
4049 BlockSizeAttr::getBlockSizeValue(adaptor.getBlockSize());
4050 outputScaleShape[lastDimLoc] = lastDimSize / blockSize;
4051 }
4052 inferredReturnShapes.push_back(ShapedTypeComponents(outputScaleShape));
4053 return success();
4054}
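// Example (illustrative sketch, not part of the implementation): assuming a
// block size of 32, an input_data operand of type tensor<4x64xf32> infers an
// output_data shape of [4, 64] and an output_scale shape of [4, 2] (the last
// dimension divided by the block size).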
4055
4056LogicalResult CastToBlockScaledOp::verify() {
4057 const Type inputDataType = getInputData().getType();
4058 const Type outputDataType = getResult(0).getType();
4059 if (failed(verifyCompatibleShape(inputDataType, outputDataType)))
4060 return emitOpError() << "require compatible shapes for input_data ("
4061 << inputDataType << ") and "
4062 << "output_data (" << outputDataType << ")";
4063
4064 const unsigned int blockSize =
4065 BlockSizeAttr::getBlockSizeValue(getBlockSize());
4066 const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
4067 if (inputDataShape.hasRank()) {
4068 const int64_t inputDataLastDim =
4069 inputDataShape.getDimSize(inputDataShape.getRank() - 1);
4070 if (ShapedType::isStatic(inputDataLastDim) &&
4071 inputDataLastDim % blockSize != 0)
4072 return emitOpError() << "expect last dimension of input_data ("
4073 << inputDataLastDim
4074 << ") to be divisible by block_size (" << blockSize
4075 << ")";
4076 }
4077
4078 const ShapeAdaptor outputDataShape = ShapeAdaptor(outputDataType);
4079 const Type outputScaleType = getResult(1).getType();
4080 const ShapeAdaptor outputScaleShape = ShapeAdaptor(outputScaleType);
4081 if (outputDataShape.hasRank() && outputScaleShape.hasRank()) {
4082 SmallVector<int64_t> outputDataDims, outputScaleDims;
4083 outputDataShape.getDims(outputDataDims);
4084 outputScaleShape.getDims(outputScaleDims);
4085
4086 if (outputDataDims.size() != outputScaleDims.size() ||
4087 failed(verifyCompatibleShape(
4088 ArrayRef<int64_t>(outputDataDims).drop_back(1),
4089 ArrayRef<int64_t>(outputScaleDims).drop_back(1))))
4090 return emitOpError() << "require compatible shapes for output_data ("
4091 << outputDataType << ") and "
4092 << "output_scale (" << outputScaleType
4093 << ") except for the last dimension";
4094
4095 const int64_t outputDataLastDim = outputDataDims.back();
4096 const SmallVector<int64_t, 2> dimsToCheck{outputDataLastDim / blockSize,
4097 outputScaleDims.back()};
4098 if (ShapedType::isStatic(outputDataLastDim) &&
4099 failed(verifyCompatibleDims(dimsToCheck)))
4100 return emitOpError()
4101 << "expect last dimension of output_scale ("
4102 << outputScaleDims.back()
4103 << ") to be equal to last dimension of output_data / block_size ("
4104 << outputDataDims.back() / blockSize << ")";
4105 }
4106
4107 return success();
4108}
4109
4110LogicalResult IfOp::inferReturnTypeComponents(
4111 MLIRContext *context, ::std::optional<Location> location,
4112 IfOp::Adaptor adaptor,
4113 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4114 llvm::SmallVector<tosa::YieldOp> yieldOps;
4115 for (Region *region : adaptor.getRegions()) {
4116 for (auto &block : *region)
4117 if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
4118 yieldOps.push_back(returnOp);
4119 }
4120
4121 if (yieldOps.empty())
4122 return failure();
4123
4124 // Get the initial type information for the yield op.
4125 llvm::SmallVector<ValueKnowledge> resultKnowledge;
4126 resultKnowledge.reserve(yieldOps.front().getNumOperands());
4127 for (auto operand : yieldOps.front().getOperands()) {
4128 resultKnowledge.push_back(
4129 ValueKnowledge::getKnowledgeFromType(operand.getType()));
4130 }
4131
4132 for (auto yieldOp : yieldOps) {
4133 if (resultKnowledge.size() != yieldOp.getNumOperands())
4134 return failure();
4135
4136 for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
4137 int32_t index = it.index();
4138 auto meet = ValueKnowledge::meet(
4139 resultKnowledge[index],
4140 ValueKnowledge::getKnowledgeFromType(it.value().getType()));
4141 if (!meet)
4142 continue;
4143 resultKnowledge[index] = meet;
4144 }
4145 }
4146
4147 for (const ValueKnowledge &result : resultKnowledge) {
4148 inferredReturnShapes.push_back(result.getShapedTypeComponents());
4149 }
4150
4151 return success();
4152}
4153
4154LogicalResult WhileOp::inferReturnTypeComponents(
4155 MLIRContext *context, ::std::optional<Location> location,
4156 WhileOp::Adaptor adaptor,
4157 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4158 llvm::SmallVector<tosa::YieldOp> yieldOps;
4159 for (auto &block : adaptor.getBodyGraph())
4160 if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
4161 yieldOps.push_back(returnOp);
4162
4163 // TOSA's while must have a tosa.yield as its terminator. If not found this
4164 // tosa.while is invalid.
4165 if (yieldOps.empty())
4166 return failure();
4167
4168 // Get the initial type information from the operand types.
4169 llvm::SmallVector<ValueKnowledge> resultKnowledge;
4170 resultKnowledge.reserve(yieldOps.front().getNumOperands());
4171 for (auto operand : yieldOps.front().getOperands()) {
4172 resultKnowledge.push_back(
4173 ValueKnowledge::getKnowledgeFromType(operand.getType()));
4174 }
4175
4176 for (auto yieldOp : yieldOps) {
4177 if (resultKnowledge.size() != yieldOp.getNumOperands())
4178 return failure();
4179
4180 for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
4181 int32_t index = it.index();
4182 if (auto meet = ValueKnowledge::meet(
4183 resultKnowledge[index],
4184 ValueKnowledge::getKnowledgeFromType(it.value().getType()))) {
4185 resultKnowledge[index] = meet;
4186 }
4187 }
4188 }
4189
4190 for (const ValueKnowledge &result : resultKnowledge) {
4191 inferredReturnShapes.push_back(result.getShapedTypeComponents());
4192 }
4193
4194 return success();
4195}
4196
4197std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
4198 if (auto vt = llvm::dyn_cast<VectorType>(getType()))
4199 return llvm::to_vector<4>(vt.getShape());
4200 return std::nullopt;
4201}
4202
4203static void printInitializationList(OpAsmParser &parser,
4204 Block::BlockArgListType blocksArgs,
4205 ValueRange initializers,
4206 StringRef prefix = "") {
4207 assert(blocksArgs.size() == initializers.size() &&
4208 "expected same length of arguments and initializers");
4209 if (initializers.empty())
4210 return;
4211
4212 parser << prefix << '(';
4213 llvm::interleaveComma(
4214 llvm::zip(blocksArgs, initializers), parser,
4215 [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
4216 parser << ")";
4217}
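// Example (illustrative sketch, not part of the implementation): given block
// arguments (%arg0, %arg1) initialized from values (%0, %1) and a prefix of
// " ", the helper above prints " (%arg0 = %0, %arg1 = %1)"; nothing is printed
// when the initializer list is empty.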
4218
4219// The parse and print methods of IfOp follow the implementation of the SCF dialect.
4220ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
4221 // Create the regions for 'then'.
4222 result.regions.reserve(2);
4223 Region *thenRegion = result.addRegion();
4224 Region *elseRegion = result.addRegion();
4225
4226 OpAsmParser::UnresolvedOperand cond;
4227
4228 if (parser.parseOperand(cond))
4229 return failure();
4230
4231 SmallVector<OpAsmParser::Argument, 4> regionArgs;
4232 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
4233
4234 // Parse the optional block arguments
4235 OptionalParseResult listResult =
4236 parser.parseOptionalAssignmentList(regionArgs, operands);
4237 if (listResult.has_value() && failed(listResult.value()))
4238 return failure();
4239
4240 // Parse a colon.
4241 if (failed(parser.parseColon()))
4242 return parser.emitError(parser.getCurrentLocation(),
4243 "expected type for condition operand");
4244
4245 // Parse the type of the condition operand
4246 Type condType;
4247 if (failed(parser.parseType(condType)))
4248 return parser.emitError(parser.getCurrentLocation(),
4249 "expected type for condition operand");
4250
4251 // Resolve operand with provided type
4252 if (failed(parser.resolveOperand(cond, condType, result.operands)))
4253 return failure();
4254
4255 // Parse optional block arg types
4256 if (listResult.has_value()) {
4257 FunctionType functionType;
4258
4259 if (failed(parser.parseType(functionType)))
4260 return parser.emitError(parser.getCurrentLocation())
4261 << "expected list of types for block arguments "
4262 << "followed by arrow type and list of return types";
4263
4264 result.addTypes(functionType.getResults());
4265
4266 if (functionType.getNumInputs() != operands.size()) {
4267 return parser.emitError(parser.getCurrentLocation())
4268 << "expected as many input types as operands "
4269 << "(expected " << operands.size() << " got "
4270 << functionType.getNumInputs() << ")";
4271 }
4272
4273 // Resolve input operands.
4274 if (failed(parser.resolveOperands(operands, functionType.getInputs(),
4275 parser.getCurrentLocation(),
4276 result.operands)))
4277 return failure();
4278 } else {
4279 // Parse optional results type list.
4280 if (parser.parseOptionalArrowTypeList(result.types))
4281 return failure();
4282 }
4283
4284 // Parse the 'then' region.
4285 if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
4286 return failure();
4287
4288 // If we find an 'else' keyword then parse the 'else' region.
4289 if (!parser.parseOptionalKeyword("else")) {
4290 if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
4291 return failure();
4292 }
4293
4294 // Parse the optional attribute list.
4295 if (parser.parseOptionalAttrDict(result.attributes))
4296 return failure();
4297 return success();
4298}
4299
4300void IfOp::print(OpAsmPrinter &p) {
4301 p << " " << getCondition();
4302
4303 printInitializationList(p, getThenGraph().front().getArguments(),
4304 getInputList(), " ");
4305 p << " : ";
4306 p << getCondition().getType();
4307
4308 if (!getInputList().empty()) {
4309 p << " (";
4310 llvm::interleaveComma(getInputList().getTypes(), p);
4311 p << ")";
4312 }
4313 p.printArrowTypeList(getResultTypes());
4314 p << " ";
4315
4316 p.printRegion(getThenGraph());
4317
4318 // Print the 'else' regions if it exists and has a block.
4319 auto &elseRegion = getElseGraph();
4320 if (!elseRegion.empty()) {
4321 p << " else ";
4322 p.printRegion(elseRegion);
4323 }
4324
4325 p.printOptionalAttrDict((*this)->getAttrs());
4326}
4327
4328LogicalResult IfOp::verify() {
4329 if (errorIfTypeOrShapeMismatch(*this, getThenGraph().front().getArguments(),
4330 "'then_graph' arguments", getInputList(),
4331 "'input_list'")
4332 .failed())
4333 return failure();
4334
4335 if (errorIfTypeOrShapeMismatch(*this, getElseGraph().front().getArguments(),
4336 "'else_graph' arguments", getInputList(),
4337 "'input_list'")
4338 .failed())
4339 return failure();
4340
4341 // If the terminator is missing, the generic MLIR verifier reports it for us.
4342 if (getThenGraph().front().mightHaveTerminator()) {
4343 auto thenYield =
4344 dyn_cast<tosa::YieldOp>(getThenGraph().front().getTerminator());
4345 if (thenYield && errorIfTypeOrShapeMismatch(
4346 *this, thenYield.getInputs(), "'then_graph' results",
4347 getOutputList(), "'output_list'")
4348 .failed())
4349 return failure();
4350 }
4351
4352 // If the terminator is missing, the generic MLIR verifier reports it for us.
4353 if (getElseGraph().front().mightHaveTerminator()) {
4354 auto elseYield =
4355 dyn_cast<tosa::YieldOp>(getElseGraph().front().getTerminator());
4356 if (elseYield && errorIfTypeOrShapeMismatch(
4357 *this, elseYield.getInputs(), "'else_graph' results",
4358 getOutputList(), "'output_list'")
4359 .failed())
4360 return failure();
4361 }
4362
4363 auto condType = getCondition().getType();
4364 if (errorIfShapeNotSizeOne(*this, condType).failed())
4365 return emitOpError() << "'condition' must be a size 1 tensor, got "
4366 << condType;
4367
4368 return success();
4369}
4370
4371LogicalResult WhileOp::verify() {
4372 if (errorIfTypeOrShapeMismatch(*this, getInputList(), "'input_list'",
4373 getOutputList(), "'output_list'")
4374 .failed())
4375 return failure();
4376
4377 if (errorIfTypeOrShapeMismatch(*this, getCondGraph().front().getArguments(),
4378 "'cond_graph' arguments", getInputList(),
4379 "'input_list'")
4380 .failed())
4381 return failure();
4382
4383 if (errorIfTypeOrShapeMismatch(*this, getBodyGraph().front().getArguments(),
4384 "'body_graph' arguments", getInputList(),
4385 "'input_list'")
4386 .failed())
4387 return failure();
4388
4389 if (getBodyGraph().front().mightHaveTerminator()) {
4390 auto bodyYield =
4391 dyn_cast<tosa::YieldOp>(getBodyGraph().front().getTerminator());
4392 if (bodyYield && errorIfTypeOrShapeMismatch(*this, bodyYield.getInputs(),
4393 "'body_graph' results",
4394 getInputList(), "'input_list'")
4395 .failed())
4396 return failure();
4397 }
4398
4399 // Condition block output must be a single element tensor with a single bool
4400 // value.
4401 if (!getCondGraph().front().mightHaveTerminator())
4402 return success();
4403
4404 auto condYield =
4405 dyn_cast<tosa::YieldOp>(getCondGraph().front().getTerminator());
4406 if (!condYield)
4407 return success();
4408
4409 if (condYield.getInputs().size() != 1)
4410 return emitOpError() << "require 'cond_graph' only have one result";
4411
4412 auto condOutType = condYield.getInputs()[0].getType();
4413 if (errorIfShapeNotSizeOne(*this, condOutType).failed())
4414 return emitOpError() << "'cond_graph' result must be a size 1 tensor, got "
4415 << condOutType;
4416
4417 if (!getElementTypeOrSelf(condOutType).isInteger(1))
4418 return emitOpError() << "'cond_graph' result must be a boolean tensor, got "
4419 << condOutType;
4420
4421 return success();
4422}
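// Example (illustrative sketch, not part of the implementation): a well-formed
// tosa.while has a cond_graph that yields exactly one value of a size-1
// boolean tensor type, e.g. tensor<1xi1>, while input_list, output_list and
// the arguments of both regions must agree in type and shape.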
4423
4424LogicalResult ReverseOp::verify() {
4425 if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
4426 /* outType = */ getOutput().getType())
4427 .failed())
4428 return failure();
4429 TensorType inputType = getInput1().getType();
4430 TensorType outputType = getOutput().getType();
4431 int32_t reverseAxis = getAxis();
4432
4433 if (reverseAxis < 0)
4434 return emitOpError("expected non-negative reverse axis");
4435 if (inputType.hasRank()) {
4436 int64_t inputRank = inputType.getRank();
4437 // We allow for a special case where the input/output shape has rank 0 and
4438 // axis is also 0.
4439 if (reverseAxis >= inputRank && (reverseAxis != 0 || inputRank != 0))
4440 return emitOpError("expect input tensor rank (")
4441 << inputRank << ") to be larger than reverse axis (" << reverseAxis
4442 << ")";
4443 }
4444 if (outputType.hasRank()) {
4445 int64_t outputRank = outputType.getRank();
4446 if (inputType.hasRank() && outputRank != inputType.getRank())
4447 return emitOpError(
4448 "expect output tensor rank to be equal to input tensor rank");
4449 if (reverseAxis >= outputRank && (reverseAxis != 0 || outputRank != 0))
4450 return emitOpError("expect output tensor rank (")
4451 << outputRank << ") to be larger than reverse axis ("
4452 << reverseAxis << ")";
4453 }
4454 return success();
4455}
4456
4457LogicalResult tosa::SelectOp::verify() {
4458 // verify input2 and input3 have same element type as output
4459 if (verifySameElementTypes(*this, /* inType = */ getOnTrue().getType(),
4460 /* outType = */ getOutput().getType())
4461 .failed() ||
4462 verifySameElementTypes(*this, /* inType = */ getOnFalse().getType(),
4463 /* outType = */ getOutput().getType())
4464 .failed()) {
4465 return failure();
4466 }
4467 // verify input1 has element type of bool
4468 auto predicateType = llvm::dyn_cast<ShapedType>(getPred().getType());
4469 if (!predicateType) {
4470 return emitOpError("expect shaped tensor for input1, got ")
4471 << getInput1().getType();
4472 }
4473 auto predicateElementType = predicateType.getElementType();
4474 if (!predicateElementType.isInteger(1)) {
4475 return emitOpError("expect element type of bool for input1, got ")
4476 << predicateElementType;
4477 }
4478
4479 return success();
4480}
4481
4482LogicalResult tosa::VariableReadOp::verify() {
4483 if (verifyVariableOpErrorIf(*this, getOutput1().getType(), "'output1'")
4484 .failed())
4485 return failure();
4486
4487 return success();
4488}
4489
4490LogicalResult tosa::VariableWriteOp::verify() {
4491 if (verifyVariableOpErrorIf(*this, getInput1().getType(), "'input1'")
4492 .failed())
4493 return failure();
4494
4495 return success();
4496}
4497
4498// The parse and print methods of WhileOp follow the implementation of the SCF dialect.
4499ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
4500 SmallVector<OpAsmParser::Argument, 4> regionArgs;
4501 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
4502 Region *cond = result.addRegion();
4503 Region *body = result.addRegion();
4504
4505 OptionalParseResult listResult =
4506 parser.parseOptionalAssignmentList(regionArgs, operands);
4507 if (listResult.has_value() && failed(listResult.value()))
4508 return failure();
4509
4510 FunctionType functionType;
4511 SMLoc typeLoc = parser.getCurrentLocation();
4512 if (failed(parser.parseColonType(functionType)))
4513 return failure();
4514
4515 result.addTypes(functionType.getResults());
4516
4517 if (functionType.getNumInputs() != operands.size()) {
4518 return parser.emitError(typeLoc)
4519 << "expected as many input types as operands "
4520 << "(expected " << operands.size() << " got "
4521 << functionType.getNumInputs() << ")";
4522 }
4523
4524 // Resolve input operands.
4525 if (failed(parser.resolveOperands(operands, functionType.getInputs(),
4526 parser.getCurrentLocation(),
4527 result.operands)))
4528 return failure();
4529
4530 // Propagate the types into the region arguments.
4531 for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
4532 regionArgs[i].type = functionType.getInput(i);
4533
4534 return failure(parser.parseRegion(*cond, regionArgs) ||
4535 parser.parseKeyword("do") || parser.parseRegion(*body) ||
4536 parser.parseOptionalAttrDictWithKeyword(result.attributes));
4537}
4538
4539void WhileOp::print(OpAsmPrinter &parser) {
4540 printInitializationList(parser, getCondGraph().front().getArguments(),
4541 getInputList(), " ");
4542 parser << " : ";
4543 parser.printFunctionalType(getInputList().getTypes(),
4544 getResults().getTypes());
4545 parser << ' ';
4546 parser.printRegion(getCondGraph(), /*printEntryBlockArgs=*/false);
4547 parser << " do ";
4548 parser.printRegion(getBodyGraph());
4549 parser.printOptionalAttrDictWithKeyword((*this)->getAttrs());
4550}
4551
4552// Create a rank-1 const tensor holding the zero point of the source tensor.
4553std::optional<Value> mlir::tosa::createZeroPointTensor(OpBuilder &builder,
4554 Location loc,
4555 Type srcElemType,
4556 int64_t zp) {
4557 srcElemType = getStorageElementTypeOrSelf(srcElemType);
4558 auto zpType = mlir::RankedTensorType::get({1}, srcElemType);
4559 if (llvm::isa<FloatType>(srcElemType)) {
4560 auto zpAttr = DenseElementsAttr::get(
4561 zpType, builder.getFloatAttr(srcElemType, static_cast<double>(zp)));
4562 return tosa::ConstOp::create(builder, loc, zpType, zpAttr);
4563 }
4564 if (llvm::isa<IntegerType>(srcElemType)) {
4565 auto zpAttr =
4566 DenseElementsAttr::get(zpType, builder.getIntegerAttr(srcElemType, zp));
4567 return tosa::ConstOp::create(builder, loc, zpType, zpAttr);
4568 }
4569 llvm::errs() << "zero point is not allowed for unsupported data types\n";
4570 return std::nullopt;
4571}
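// Hypothetical call site, for illustration of the helper above:
//   std::optional<Value> zp =
//       tosa::createZeroPointTensor(builder, loc, builder.getI8Type(), /*zp=*/0);
//   // On success, *zp is a tosa.const of type tensor<1xi8> holding 0; for an
//   // unsupported element type the result is std::nullopt.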
4572
4573//===----------------------------------------------------------------------===//
4574// TOSA Shape and Shape Operators Helper functions.
4575//===----------------------------------------------------------------------===//
4576
4577bool mlir::tosa::isa_tosa_shape_type(mlir::Type t) {
4578 return mlir::isa<tosa::shapeType>(t);
4579}
4580
4581LogicalResult
4582mlir::tosa::shapeType::verify(function_ref<InFlightDiagnostic()> emitError,
4583 int rank) {
4584 if (rank < 0)
4585 return emitError() << "invalid rank (must be >= 0): " << rank;
4586 return success();
4587}
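// For illustration: a type such as !tosa.shape<3> passes this hook (rank 3 is
// non-negative), while constructing a shape type with a negative rank is
// diagnosed as invalid.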
4588
4589LogicalResult OpTrait::tosa::verifyTosaResolvableShapeOperands(Operation *op) {
4590 for (auto v : op->getOperands()) {
4591 if (mlir::isa<::mlir::tosa::shapeType>(v.getType())) {
4592 Operation *definingOp = v.getDefiningOp();
4593 if (!definingOp || !definingOp->hasTrait<TosaShapeOperator>()) {
4594 return op->emitOpError("shape operand is not compile time resolvable");
4595 }
4596 }
4597 }
4598 return success();
4599}
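// Illustrative example (the assembly shown is approximate): a shape operand
// such as
//   %s = tosa.const_shape {values = dense<[1, 2, 3]> : tensor<3xindex>} : !tosa.shape<3>
// is compile-time resolvable because its defining op carries the
// TosaShapeOperator trait; a shape value with no defining op (e.g. a block
// argument) is rejected above.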
4600
4601LogicalResult OpTrait::tosa::verifyTosaShapeOperator(Operation *op) {
4602 for (auto type : op->getOperandTypes()) {
4603 if (!mlir::isa<mlir::tosa::shapeType>(type)) {
4604 return op->emitOpError("must have operands with tosa shape type");
4605 }
4606 }
4607 for (auto type : op->getResultTypes()) {
4608 if (!mlir::isa<mlir::tosa::shapeType>(type)) {
4609 return op->emitOpError("must have result with tosa shape type");
4610 }
4611 }
4612 return success();
4613}
4614
4615LogicalResult
4616OpTrait::tosa::verifyTosaShapeOperatorWithSameRanks(Operation *op) {
4617 if (failed(OpTrait::impl::verifyAtLeastNOperands(op, 1)) ||
4618 failed(verifyTosaShapeOperator(op)))
4619 return failure();
4620
4621 // helper that returns the rank of a shape type
4622 auto getRank = [](const Type type) {
4623 return mlir::cast<mlir::tosa::shapeType>(type).getRank();
4624 };
4625 auto operandTypes = op->getOperandTypes();
4626 auto resultTypes = op->getResultTypes();
4627
4628 auto rank = getRank(*op->getOperandTypes().begin());
4629 for (auto type : operandTypes) {
4630 if (getRank(type) != rank) {
4631 return op->emitOpError("operands don't have matching ranks");
4632 }
4633 }
4634 for (auto type : resultTypes) {
4635 if (getRank(type) != rank) {
4636 return op->emitOpError("result shape has different rank than operands");
4637 }
4638 }
4639 return success();
4640}
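// For illustration: an op carrying this trait with two !tosa.shape<2> operands
// must also produce a !tosa.shape<2> result; mixing !tosa.shape<2> and
// !tosa.shape<3> operands or results fails the rank checks above.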
4641
4642//===----------------------------------------------------------------------===//
4643// TOSA Shape Operators verify functions.
4644//===----------------------------------------------------------------------===//
4645
4646LogicalResult tosa::ConstShapeOp::verify() {
4647 // check that the values attribute has rank 1
4648 auto valuesRank = getValues().getType().getRank();
4649 if (valuesRank != 1)
4650 return emitOpError("expect elements in attribute values with rank 1");
4651 // check that the number of elements in the values attr equals the result shape rank
4652 auto count = getValues().getNumElements();
4653 auto rank = (cast<tosa::shapeType>(getResult().getType())).getRank();
4654 if (count != rank && (count != 1 || rank != 0)) {
4655 return emitOpError("expect number of elements in attribute values (")
4656 << count << ") to be equal to the rank (" << rank
4657 << ") for the result shape type";
4658 }
4659 return success();
4660}
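// For example, a values attribute holding three elements must feed a
// !tosa.shape<3> result; the one tolerated mismatch is a single element
// paired with a rank-0 result shape (count == 1, rank == 0).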
4661
4662//===----------------------------------------------------------------------===//
4663// TOSA Attribute Definitions.
4664//===----------------------------------------------------------------------===//
4665
4666#define GET_ATTRDEF_CLASSES
4667#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
4668
4669//===----------------------------------------------------------------------===//
4670// TOSA Type Definitions.
4671//===----------------------------------------------------------------------===//
4672#define GET_TYPEDEF_CLASSES
4673#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
4674
4675//===----------------------------------------------------------------------===//
4676// TOSA Operator Definitions.
4677//===----------------------------------------------------------------------===//
4678
4679#define GET_OP_CLASSES
4680#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"