MLIR 23.0.0git
TosaOps.cpp
Go to the documentation of this file.
1//===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// \file
10// This file implements the TOSA Specification:
11// https://www.mlplatform.org/tosa/tosa_spec.html
12//
13//===----------------------------------------------------------------------===//
14
25#include "mlir/IR/Matchers.h"
29#include "llvm/ADT/APFloat.h"
30#include "llvm/ADT/SmallVectorExtras.h"
31#include "llvm/ADT/TypeSwitch.h"
32
33#include <numeric>
34#include <type_traits>
35
36using namespace mlir;
37using namespace mlir::tosa;
38
39#include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"
41
42//===----------------------------------------------------------------------===//
43// Tosa dialect interface includes.
44//===----------------------------------------------------------------------===//
45
46#include "mlir/Dialect/Tosa/IR/TosaAvailability.cpp.inc"
47#include "mlir/Dialect/Tosa/IR/TosaEnums.cpp.inc"
48#include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
49#include "mlir/Dialect/Tosa/IR/TosaOpAvailabilityImpl.inc"
50
51namespace {
52#include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"
53
54//===----------------------------------------------------------------------===//
55// Dialect Function Inliner Interface.
56//===----------------------------------------------------------------------===//
57struct TosaInlinerInterface : public DialectInlinerInterface {
58 using DialectInlinerInterface::DialectInlinerInterface;
59
60 //===--------------------------------------------------------------------===//
61 // Analysis Hooks.
62 //===--------------------------------------------------------------------===//
63
64 /// All operations can be inlined by default.
65 bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
66 IRMapping &map) const final {
67 return true;
68 }
69
70 /// All regions with If and While parent operators can be inlined.
71 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
72 IRMapping &map) const final {
73 return (isa<tosa::IfOp>(dest->getParentOp()) ||
74 isa<tosa::WhileOp>(dest->getParentOp()));
75 }
76};
77
78/// This class implements the bytecode interface for the Tosa dialect.
79struct TosaDialectBytecodeInterface : public BytecodeDialectInterface {
80 TosaDialectBytecodeInterface(Dialect *dialect)
81 : BytecodeDialectInterface(dialect) {}
82
83 //===--------------------------------------------------------------------===//
84 // Attributes
85
86 Attribute readAttribute(DialectBytecodeReader &reader) const override {
87 return ::readAttribute(getContext(), reader);
88 }
89
90 LogicalResult writeAttribute(Attribute attr,
91 DialectBytecodeWriter &writer) const override {
92 return ::writeAttribute(attr, writer);
93 }
94
95 //===--------------------------------------------------------------------===//
96 // Types
97
98 Type readType(DialectBytecodeReader &reader) const override {
99 return ::readType(getContext(), reader);
100 }
101
102 LogicalResult writeType(Type type,
103 DialectBytecodeWriter &writer) const override {
104 return ::writeType(type, writer);
105 }
106
107 void writeVersion(DialectBytecodeWriter &writer) const final {
108 // TODO: Populate.
109 }
110
111 std::unique_ptr<DialectVersion>
112 readVersion(DialectBytecodeReader &reader) const final {
113 // TODO: Populate
114 reader.emitError("Dialect does not support versioning");
115 return nullptr;
116 }
117
118 LogicalResult upgradeFromVersion(Operation *topLevelOp,
119 const DialectVersion &version) const final {
120 return success();
121 }
122};
123
124} // namespace
125
126//===----------------------------------------------------------------------===//
127// TOSA control flow support.
128//===----------------------------------------------------------------------===//
129
130/// Returns the while loop body.
131SmallVector<Region *> tosa::WhileOp::getLoopRegions() {
132 return {&getBodyGraph()};
133}
134
135//===----------------------------------------------------------------------===//
136// TOSA variable operator support.
137//===----------------------------------------------------------------------===//
138
140 return map_to_vector(shape, [](int64_t dim) {
141 return dim == -1 ? ShapedType::kDynamic : dim;
142 });
143}
144
145// returns type of variable op
146RankedTensorType mlir::tosa::getVariableType(tosa::VariableOp variableOp) {
147 Type elementType = variableOp.getType();
148 DenseIntElementsAttr varShapeAttr = variableOp.getVarShape();
149 auto shape = convertToMlirShape(to_vector(varShapeAttr.getValues<int64_t>()));
150 return RankedTensorType::get(shape, elementType);
151}
152
153//===----------------------------------------------------------------------===//
154// Tosa dialect initialization.
155//===----------------------------------------------------------------------===//
156
/// Registers everything the Tosa dialect provides with the context.
/// The type, operation and attribute lists are expanded from tablegen-generated
/// .inc files selected by the GET_*_LIST macros.
void TosaDialect::initialize() {
  // Dialect types (expanded from TosaOpsTypesBase.cpp.inc).
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
      >();
  // Dialect operations (expanded from TosaOps.cpp.inc).
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
      >();
  // Dialect attributes (expanded from TosaAttributes.cpp.inc).
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
      >();
  // Dialect-level interfaces defined earlier in this file.
  addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
  // Only *promises* ShardingInterface for the listed ops; the implementation
  // is not provided here and is presumably attached elsewhere (e.g. an
  // external model) — NOTE(review): confirm where it is registered.
  declarePromisedInterfaces<
      shard::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
      ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp,
      LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,
      LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,
      BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp,
      NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp,
      GreaterEqualOp, MatMulOp>();
}
180
181Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
182 Type type, Location loc) {
183 // Tosa dialect constants only support ElementsAttr unlike standard dialect
184 // constant which supports all attributes.
185 if (llvm::isa<shapeType>(type) && llvm::isa<DenseIntElementsAttr>(value)) {
186 return tosa::ConstShapeOp::create(builder, loc, type,
187 llvm::cast<DenseIntElementsAttr>(value));
188 }
189 if (llvm::isa<ElementsAttr>(value))
190 return tosa::ConstOp::create(builder, loc, type,
191 llvm::cast<ElementsAttr>(value));
192 return nullptr;
193}
194
195//===----------------------------------------------------------------------===//
196// Parsers and printers
197//===----------------------------------------------------------------------===//
198
199namespace {
200
201ParseResult getShapeAndElementType(OpAsmParser &parser, Type parsedType,
202 DenseElementsAttr &varShapeAttr,
203 TypeAttr &typeAttr) {
204 if (auto shapedType = dyn_cast<ShapedType>(parsedType)) {
205 if (!shapedType.hasRank())
206 return parser.emitError(parser.getCurrentLocation())
207 << "expected ranked type";
208
209 auto elementType = shapedType.getElementType();
210 typeAttr = TypeAttr::get(elementType);
211 ArrayRef<int64_t> shape = shapedType.getShape();
212 Builder builder(parser.getContext());
213 varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
214 return success();
215 }
216 return parser.emitError(parser.getCurrentLocation())
217 << "expected shaped type";
218}
219
220} // namespace
221
222// parses the optional initial value or type for a tosa variable
223// with initial value:
224// tosa.variable @name = dense<0.0> : tensor<1x8xf32>
225//
226// without initial value:
227// tosa.variable @name : tensor<1x8xf32>
229 OpAsmParser &parser, DenseElementsAttr &varShapeAttr, TypeAttr &typeAttr,
230 Attribute &initialValueAttr) {
231 if (succeeded(parser.parseOptionalEqual())) {
232 if (failed(parser.parseAttribute(initialValueAttr))) {
233 return parser.emitError(parser.getCurrentLocation())
234 << "expected attribute";
235 }
236 if (auto typedAttr = dyn_cast<TypedAttr>(initialValueAttr)) {
237 return getShapeAndElementType(parser, typedAttr.getType(), varShapeAttr,
238 typeAttr);
239 }
240 return parser.emitError(parser.getCurrentLocation())
241 << "expected Typed attr";
242 }
243
244 initialValueAttr = nullptr;
245 Type parsedType;
246 if (failed(parser.parseColonType(parsedType))) {
247 return parser.emitError(parser.getCurrentLocation())
248 << "expected type after colon";
249 }
250 return getShapeAndElementType(parser, parsedType, varShapeAttr, typeAttr);
251}
252
254 OpAsmPrinter &p, Operation *op, DenseElementsAttr varShapeAttr,
255 TypeAttr typeAttr, Attribute initialValueAttr) {
256 bool needsSpace = false;
257 if (!dyn_cast_or_null<TypedAttr>(initialValueAttr)) {
258 auto shape =
259 convertToMlirShape(to_vector(varShapeAttr.getValues<int64_t>()));
260 Type elementType = typeAttr.getValue();
261 RankedTensorType tensorType =
262 RankedTensorType::get(ArrayRef<int64_t>(shape), elementType);
263 auto tensorTypeAttr = TypeAttr::get(tensorType);
264 p << ": ";
265 p.printAttribute(tensorTypeAttr);
266 needsSpace = true; // subsequent attr value needs a space separator
267 }
268 if (initialValueAttr) {
269 if (needsSpace)
270 p << ' ';
271 p << "= ";
272 p.printAttribute(initialValueAttr);
273 }
274}
275
276namespace {
277
278// parse attributes with special handling for tosa enum attributes
279template <typename EnumType>
280ParseResult parseAttrEntryWithEnumHandling(OpAsmParser &parser,
281 NamedAttrList &outAttrs) {
282 llvm::StringRef name;
283 if (parser.parseOptionalKeyword(&name) || parser.parseEqual())
284 return failure();
285
286 // special handling: rounding_mode accepts a *bare* RoundingMode enum
287 // keyword.
288 llvm::StringRef kw;
289 if constexpr (std::is_same_v<EnumType, tosa::RoundingMode>) {
290 if (name == "rounding_mode" &&
291 succeeded(parser.parseOptionalKeyword(&kw))) {
292 auto sym = symbolizeRoundingMode(kw);
293 if (!sym)
294 return parser.emitError(parser.getCurrentLocation())
295 << "invalid rounding_mode value: " << kw;
296 auto attr = RoundingModeAttr::get(parser.getContext(), sym.value());
297 outAttrs.push_back(NamedAttribute(name, attr));
298 return success();
299 }
300 }
301 // special handling: mode accepts a *bare* ResizeMode enum keyword.
302 if constexpr (std::is_same_v<EnumType, tosa::ResizeMode>) {
303 if (name == "mode" && succeeded(parser.parseOptionalKeyword(&kw))) {
304 auto sym = symbolizeResizeMode(kw);
305 if (!sym)
306 return parser.emitError(parser.getCurrentLocation())
307 << "invalid resize mode value: " << kw;
308 auto attr = ResizeModeAttr::get(parser.getContext(), sym.value());
309 outAttrs.push_back(NamedAttribute(name, attr));
310 return success();
311 }
312 }
313 // special handling: nan_mode accepts a *bare* NanPropagationMode enum
314 // keyword.
315 if constexpr (std::is_same_v<EnumType, tosa::NanPropagationMode>) {
316 if (name == "nan_mode" && succeeded(parser.parseOptionalKeyword(&kw))) {
317 auto sym = symbolizeNanPropagationMode(kw);
318 if (!sym)
319 return parser.emitError(parser.getCurrentLocation())
320 << "invalid nan_mode value: " << kw;
321 auto attr = NanPropagationModeAttr::get(parser.getContext(), sym.value());
322 outAttrs.push_back(NamedAttribute(name, attr));
323 return success();
324 }
325 }
326
327 // special handling: block_size accepts a *bare* BlockSizeMode enum
328 if constexpr (std::is_same_v<EnumType, tosa::BlockSize>) {
329 if (name == "block_size" && succeeded(parser.parseOptionalKeyword(&kw))) {
330 auto sym = symbolizeBlockSize(kw);
331 if (!sym)
332 return parser.emitError(parser.getCurrentLocation())
333 << "invalid block_size value: " << kw;
334 auto attr = BlockSizeAttr::get(parser.getContext(), sym.value());
335 outAttrs.push_back(NamedAttribute(name, attr));
336 return success();
337 }
338 }
339
340 // Default path: parse any normal attribute literal, including fully qualified
341 // enum keyword
342 Attribute attr;
343 return parser.parseAttribute(attr, name, outAttrs);
344}
345
346template <typename EnumType>
347ParseResult parseWithEnumHandling(OpAsmParser &parser, OperationState &result) {
348 // parse operands
350 if (parser.parseCommaSeparatedList(
351 [&]() { return parser.parseOperand(operands.emplace_back()); }))
352 return failure();
353
354 // Parse { attr-dict } with special handling for enum bare token
355 NamedAttrList attrs;
356 if (succeeded(parser.parseOptionalLBrace()) &&
357 failed(parser.parseOptionalRBrace())) {
358 do {
359 if (parseAttrEntryWithEnumHandling<EnumType>(parser, attrs))
360 return failure();
361 } while (succeeded(parser.parseOptionalComma()));
362 if (parser.parseRBrace())
363 return failure();
364 }
365
366 FunctionType fnTy;
367 if (parser.parseColonType(fnTy))
368 return failure();
369
370 // Resolve operands and types
371 if (failed(parser.resolveOperands(operands, fnTy.getInputs(),
372 parser.getCurrentLocation(),
373 result.operands)))
374 return failure();
375
376 result.addTypes(fnTy.getResults());
377 result.addAttributes(attrs);
378
379 return success();
380}
381
382void printNamedAttr(OpAsmPrinter &parser, const NamedAttribute namedAttr) {
383 parser << namedAttr.getName().strref() << " = ";
384 auto attr = namedAttr.getValue();
385 if (auto roundingModeAttr = dyn_cast<tosa::RoundingModeAttr>(attr)) {
386 parser << roundingModeAttr.getValue();
387 } else if (auto resizeModeAttr = dyn_cast<tosa::ResizeModeAttr>(attr)) {
388 parser << resizeModeAttr.getValue();
389 } else if (auto nanPropagationModeAttr =
390 dyn_cast<tosa::NanPropagationModeAttr>(attr)) {
391 parser << nanPropagationModeAttr.getValue();
392 } else if (auto blockSizeAttr = dyn_cast<tosa::BlockSizeAttr>(attr)) {
393 parser << blockSizeAttr.getValue();
394 } else {
395 parser.printAttribute(attr);
396 }
397}
398
399// print with special handling for default valued NanPropagationMode attribute
400void printWithNanPropagationHandling(OpAsmPrinter &parser, Operation *op) {
401 parser << " ";
402 parser.printOperands(op->getOperands());
403
404 NamedAttrList toPrint(op->getAttrs());
405 // remove default NanPropagate attribute
406 const auto kDefaultNanValue = NanPropagationMode::PROPAGATE;
407 for (auto attr : op->getAttrs()) {
408 if (auto nanAttr = dyn_cast<NanPropagationModeAttr>(attr.getValue())) {
409 if (nanAttr.getValue() == kDefaultNanValue) {
410 // elide from toPrint
411 toPrint.erase(attr.getName());
412 break;
413 }
414 }
415 }
416
417 if (!toPrint.empty()) {
418 parser << " {";
419 llvm::interleaveComma(toPrint, parser, [&](const NamedAttribute namedAttr) {
420 printNamedAttr(parser, namedAttr);
421 });
422 parser << "}";
423 }
424
425 parser << " : ";
426 parser.printFunctionalType(op);
427}
428
429// print with special handling for enums: RoundingMode, ResizeMode
430void printWithEnumHandling(OpAsmPrinter &parser, Operation *op) {
431 parser << " ";
432 parser.printOperands(op->getOperands());
433
434 if (!op->getAttrs().empty()) {
435 parser << " {";
436 llvm::interleaveComma(op->getAttrs(), parser,
437 [&](const NamedAttribute namedAttr) {
438 printNamedAttr(parser, namedAttr);
439 });
440 parser << "}";
441 }
442
443 parser << " : ";
444 parser.printFunctionalType(op);
445}
446
447} // namespace
448
449ParseResult RescaleOp::parse(OpAsmParser &parser, OperationState &result) {
450 return parseWithEnumHandling<tosa::RoundingMode>(parser, result);
451}
452
453void RescaleOp::print(OpAsmPrinter &parser) {
454 printWithEnumHandling(parser, *this);
455}
456
457ParseResult ApplyScaleOp::parse(OpAsmParser &parser, OperationState &result) {
458 return parseWithEnumHandling<tosa::RoundingMode>(parser, result);
459}
460
461void ApplyScaleOp::print(OpAsmPrinter &parser) {
462 printWithEnumHandling(parser, *this);
463}
464
465ParseResult ResizeOp::parse(OpAsmParser &parser, OperationState &result) {
466 return parseWithEnumHandling<tosa::ResizeMode>(parser, result);
467}
468
469void ResizeOp::print(OpAsmPrinter &parser) {
470 printWithEnumHandling(parser, *this);
471}
472
473ParseResult ArgMaxOp::parse(OpAsmParser &parser, OperationState &result) {
474 return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
475}
476
477void ArgMaxOp::print(OpAsmPrinter &parser) {
478 printWithNanPropagationHandling(parser, *this);
479}
480
481ParseResult MaxPool2dOp::parse(OpAsmParser &parser, OperationState &result) {
482 return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
483}
484
485void MaxPool2dOp::print(OpAsmPrinter &parser) {
486 printWithNanPropagationHandling(parser, *this);
487}
488
489ParseResult MaxPool2dAdaptiveOp::parse(OpAsmParser &parser,
491 return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
492}
493
494void MaxPool2dAdaptiveOp::print(OpAsmPrinter &parser) {
495 printWithNanPropagationHandling(parser, *this);
496}
497
498ParseResult ClampOp::parse(OpAsmParser &parser, OperationState &result) {
499 return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
500}
501
502void ClampOp::print(OpAsmPrinter &parser) {
503 printWithNanPropagationHandling(parser, *this);
504}
505
506ParseResult MaximumOp::parse(OpAsmParser &parser, OperationState &result) {
507 return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
508}
509
510void MaximumOp::print(OpAsmPrinter &parser) {
511 printWithNanPropagationHandling(parser, *this);
512}
513
514ParseResult MinimumOp::parse(OpAsmParser &parser, OperationState &result) {
515 return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
516}
517
518void MinimumOp::print(OpAsmPrinter &parser) {
519 printWithNanPropagationHandling(parser, *this);
520}
521
522ParseResult ReduceMaxOp::parse(OpAsmParser &parser, OperationState &result) {
523 return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
524}
525
526void ReduceMaxOp::print(OpAsmPrinter &parser) {
527 printWithNanPropagationHandling(parser, *this);
528}
529
530ParseResult ReduceMinOp::parse(OpAsmParser &parser, OperationState &result) {
531 return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
532}
533
534void ReduceMinOp::print(OpAsmPrinter &parser) {
535 printWithNanPropagationHandling(parser, *this);
536}
537
538ParseResult MatmulTBlockScaledOp::parse(OpAsmParser &parser,
540 return parseWithEnumHandling<tosa::BlockSize>(parser, result);
541}
542
543void MatmulTBlockScaledOp::print(OpAsmPrinter &parser) {
544 printWithEnumHandling(parser, *this);
545}
546
547ParseResult CastFromBlockScaledOp::parse(OpAsmParser &parser,
549 return parseWithEnumHandling<tosa::BlockSize>(parser, result);
550}
551
552void CastFromBlockScaledOp::print(OpAsmPrinter &parser) {
553 printWithEnumHandling(parser, *this);
554}
555
556ParseResult CastToBlockScaledOp::parse(OpAsmParser &parser,
558 return parseWithEnumHandling<tosa::BlockSize>(parser, result);
559}
560
561void CastToBlockScaledOp::print(OpAsmPrinter &parser) {
562 printWithEnumHandling(parser, *this);
563}
564
565ParseResult Conv2DBlockScaledOp::parse(OpAsmParser &parser,
567 return parseWithEnumHandling<tosa::BlockSize>(parser, result);
568}
569
570void Conv2DBlockScaledOp::print(OpAsmPrinter &parser) {
571 printWithEnumHandling(parser, *this);
572}
573
574//===----------------------------------------------------------------------===//
575// Tosa utilities.
576//===----------------------------------------------------------------------===//
577
/// Exact integer division (the spec's idiv_check): returns `lhs / rhs` when
/// `rhs` evenly divides `lhs`, and std::nullopt otherwise.
/// A zero divisor also yields std::nullopt. Evaluating `lhs % 0` is undefined
/// behaviour (typically SIGFPE), and verifiers must report bad attributes
/// (e.g. a zero stride) as errors rather than crash.
static std::optional<int64_t> idivCheck(const int64_t lhs, const int64_t rhs) {
  if (rhs == 0 || lhs % rhs != 0)
    return std::nullopt;
  return lhs / rhs;
}
583
585 auto srcType = getElementTypeOrSelf(type);
586 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcType))
587 srcType = getStorageElementTypeFromQuantized(quantType);
588 return srcType;
589}
590
594
595static LogicalResult verifyRescaleValueAndZpTypes(Operation *op, Value val,
596 Value valZp, StringRef name) {
598 Type eZpType = getStorageElementTypeOrSelf(valZp.getType());
599
600 bool bothInts =
601 mlir::isa<IntegerType>(eType) && mlir::isa<IntegerType>(eZpType);
602 bool sameBitWidth =
603 (eType.getIntOrFloatBitWidth() == eZpType.getIntOrFloatBitWidth());
604
605 if (!bothInts || !sameBitWidth) {
606 return op->emitOpError()
607 << "expected " << name << " and " << name
608 << "_zp to both be integer of the same bitwidth, but got " << eType
609 << " vs. " << eZpType;
610 }
611 return success();
612}
613
614// Create a pad-const const tensor with value of `val` of required data-type
616 Value src, int32_t val) {
617 const auto srcType = getElementTypeOrSelf(src);
618 const auto srcElemType = getStorageElementTypeOrSelf(src);
619 const auto padConstType = mlir::RankedTensorType::get({1}, srcType);
620 const auto padConstEType = mlir::RankedTensorType::get({1}, srcElemType);
621 const auto padConstAttr{
622 llvm::isa<FloatType>(srcElemType)
623 ? DenseElementsAttr::get(padConstEType,
624 builder.getFloatAttr(srcElemType, val))
625 : DenseElementsAttr::get(padConstEType,
626 builder.getIntegerAttr(srcElemType, val))};
627 return tosa::ConstOp::create(builder, loc, padConstType, padConstAttr);
628}
629
631 if (dyn_cast<tosa::mxint8Type>(type))
632 return 8;
633 return type.getIntOrFloatBitWidth();
634}
635
636// Update dim size if current dim is dynamic, otherwise raise an error if sizes
637// do not match
638LogicalResult tryUpdateDimOrFailure(Operation *op, int64_t &currDim,
639 const int64_t newDim,
640 const StringRef operandName,
641 const StringRef dimName) {
642 if (ShapedType::isDynamic(currDim)) {
643 currDim = newDim;
644 return success();
645 } else if (ShapedType::isStatic(newDim) && currDim != newDim) {
646 return op->emitOpError("expected ")
647 << dimName << " of " << operandName << " to match size " << currDim
648 << ", got " << newDim;
649 }
650 return success();
651}
652
655 auto printDim = [&](int64_t dim) {
656 if (ShapedType::isDynamic(dim))
657 diag << "?";
658 else
659 diag << dim;
660 };
661
662 llvm::interleaveComma(shape, diag, printDim);
663}
664
665static LogicalResult
667 ArrayRef<int64_t> expectedShape,
668 StringRef outputName = "output") {
669 assert(outputType.hasRank() && "expected output type to be ranked");
670
671 if (succeeded(verifyCompatibleShape(outputType.getShape(), expectedShape)))
672 return success();
673
674 InFlightDiagnostic diag = op->emitOpError("expected ");
675 diag << outputName << " shape ";
676 printShapeToDiagnostic(diag, outputType.getShape());
677 diag << " to be compatible with inferred shape ";
678 printShapeToDiagnostic(diag, expectedShape);
679 return diag;
680}
681
683 Operation *op, const int64_t inputSize, const int64_t kernelSize,
684 const int64_t outputSize, const int64_t padBefore, const int64_t padAfter,
685 const int64_t stride, const int64_t dilation, const llvm::StringRef dimName,
686 const llvm::StringRef dimAxis, const llvm::StringRef padBeforeName,
687 const llvm::StringRef padAfterName) {
688 if (inputSize == ShapedType::kDynamic || kernelSize == ShapedType::kDynamic)
689 return success();
690
691 // ERROR_IF: O != idiv_check(I - 1 + pa + pb - (K - 1) * d, s) + 1
692
693 const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
694 inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
695 stride);
696 if (!calculatedOutSizeMinusOne.has_value())
697 return op->emitOpError("expected input_")
698 << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
699 << padAfterName << " - (kernel_" << dimName << " - 1) * dilation_"
700 << dimAxis << " to be wholly divisible by stride_" << dimAxis
701 << ", got (" << inputSize << " - 1 + " << padBefore << " + "
702 << padAfter << " - (" << kernelSize << " - 1) * " << dilation
703 << ") / " << stride;
704
705 const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
706 if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
707 return op->emitOpError("calculated output ")
708 << dimName << " did not match expected: "
709 << "calculated=" << calculatedOutSize << ", expected=" << outputSize;
710
711 return success();
712}
713
714//===----------------------------------------------------------------------===//
715// mxint8Type DenseElementTypeInterface implementation.
716//===----------------------------------------------------------------------===//
717size_t mlir::tosa::mxint8Type::getDenseElementBitSize() const { return 8; }
718
720mlir::tosa::mxint8Type::convertToAttribute(ArrayRef<char> rawData) const {
721 assert(rawData.size() == 1 && "expected 1 byte for tosa.mxint8 element");
722 const auto intType = IntegerType::get(getContext(), 8);
723 return intType.convertToAttribute(rawData);
724}
725
726LogicalResult mlir::tosa::mxint8Type::convertFromAttribute(
728 const auto intAttr = dyn_cast<IntegerAttr>(attr);
729 if (!intAttr)
730 return failure();
731 const Type attrType = intAttr.getType();
732 if (!attrType.isSignlessInteger(8))
733 return failure();
734 return cast<IntegerType>(attrType).convertFromAttribute(attr, result);
735}
736
737//===----------------------------------------------------------------------===//
738// TOSA Operator Verifiers.
739//===----------------------------------------------------------------------===//
740
741template <typename T>
742static LogicalResult verifyConvOp(T op) {
743 const auto inputType = llvm::dyn_cast<TensorType>(op.getInput().getType());
744 const auto weightType = llvm::dyn_cast<TensorType>(op.getWeight().getType());
745
746 auto inputEType = inputType.getElementType();
747 auto weightEType = weightType.getElementType();
748 auto biasEType =
749 llvm::cast<ShapedType>(op.getBias().getType()).getElementType();
750 auto resultEType =
751 llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
752 bool biasIsFloat = llvm::isa<FloatType>(biasEType);
753 bool resultIsFloat = llvm::isa<FloatType>(resultEType);
754
755 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
756 inputEType = getStorageElementTypeFromQuantized(quantType);
757
758 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType))
759 weightEType = getStorageElementTypeFromQuantized(quantType);
760
761 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
762 biasEType = getStorageElementTypeFromQuantized(quantType);
763
764 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
765 resultEType = getStorageElementTypeFromQuantized(quantType);
766
767 if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
768 // for now, only enforce bias element type == result element type for
769 // float types.
770 op.emitOpError(
771 "expect both bias and result to have same element type, got ")
772 << biasEType << " and " << resultEType;
773 return failure();
774 }
775
776 if (isa<Float8E5M2Type>(inputEType) || isa<Float8E4M3FNType>(inputEType) ||
777 isa<Float8E5M2Type>(weightEType) || isa<Float8E4M3FNType>(weightEType)) {
778 if (inputEType != weightEType) {
779 op.emitOpError(
780 "expect both input and weight to have same element type, got ")
781 << inputEType << " and " << weightEType;
782 return failure();
783 }
784 }
785
786 bool inputIsFloat = llvm::isa<FloatType>(inputEType);
787 bool weightIsFloat = llvm::isa<FloatType>(weightEType);
788
789 // Either both must be float or both non-float.
790 if (inputIsFloat != weightIsFloat) {
791 op.emitOpError(
792 "expect both input and weight to be float or not together, got ")
793 << inputEType << " and " << weightEType;
794 return failure();
795 }
796
797 auto inputZpEType = getStorageElementTypeOrSelf(op.getInputZp().getType());
798 if (inputEType != inputZpEType) {
799 return op.emitOpError("expect both input and its zero point are the same "
800 "element type, got ")
801 << inputEType << " and " << inputZpEType;
802 }
803
804 auto weightZpEType = getStorageElementTypeOrSelf(op.getWeightZp().getType());
805 if (weightEType != weightZpEType) {
806 return op.emitOpError("expect both weight and its zero point are the same "
807 "element type, got ")
808 << weightEType << " and " << weightZpEType;
809 }
810
811 FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
812 if (succeeded(maybeIZp) && op.verifyInputZeroPoint(*maybeIZp).failed())
813 return failure();
814
815 FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
816 if (succeeded(maybeWZp) && op.verifyWeightZeroPoint(*maybeWZp).failed())
817 return failure();
818
819 return success();
820}
821
822LogicalResult tosa::ConstOp::verify() {
823
824 auto attrType = llvm::dyn_cast<TensorType>(getValuesAttr().getType());
825 auto outputType = llvm::dyn_cast<TensorType>(getOutput().getType());
826
827 if (!attrType || !outputType) {
828 emitOpError("expected tensors for attr/result type");
829 return failure();
830 }
831
832 if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
833 outputType.getElementType())) {
834 if (getStorageElementTypeFromQuantized(result) == attrType.getElementType())
835 return success();
836 }
837
838 if (attrType.getElementType() != outputType.getElementType()) {
839 emitOpError("expected same attr/result element types");
840 return failure();
841 }
842
843 return success();
844}
845
846template <typename T>
847static LogicalResult verifyConvOpModes(T op) {
848 auto inputEType =
849 llvm::cast<ShapedType>(op.getInput().getType()).getElementType();
850
851 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
852 inputEType = getStorageElementTypeFromQuantized(quantType);
853
854 auto accType = op.getAccType();
855 if (inputEType.isInteger(8) && !accType.isInteger(32))
856 return op.emitOpError("accumulator type for i8 tensor is not i32, got ")
857 << accType;
858
859 if (inputEType.isInteger(16) && !accType.isInteger(48))
860 return op.emitOpError("accumulator type for i16 tensor is not i48, got ")
861 << accType;
862
863 if (isa<Float8E5M2Type, Float8E4M3Type>(inputEType) &&
864 !(accType.isF16() || accType.isF32()))
865 return op.emitOpError("accumulator type for f8 tensor is not f16/f32, got ")
866 << accType;
867
868 if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
869 return op.emitOpError(
870 "accumulator type for f16 tensor is not f16/f32, got ")
871 << accType;
872
873 if (inputEType.isBF16() && !accType.isF32())
874 return op.emitOpError("accumulator type for bf16 tensor is not f32, got ")
875 << accType;
876
877 if (inputEType.isF32() && !accType.isF32())
878 return op.emitOpError("accumulator type for f32 tensor is not f32, got ")
879 << accType;
880
881 auto resultEType =
882 llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
883
884 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
885 resultEType = getStorageElementTypeFromQuantized(quantType);
886
887 return success();
888}
889
890//===----------------------------------------------------------------------===//
891// ERROR_IF functions.
892// ERROR_IF is a predicate that must set an error if the condition holds.
893//===----------------------------------------------------------------------===//
894
895template <typename T>
896static LogicalResult verifyConvOpErrorIf(T op) {
897 llvm::ArrayRef<int64_t> padding = op.getPad();
898 if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
899 return op.emitOpError("expect all padding values to be >= 0, got ")
900 << padding;
901
902 llvm::ArrayRef<int64_t> strides = op.getStride();
903 if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
904 return op.emitOpError("expect all stride values to be >= 1, got ")
905 << strides;
906
907 llvm::ArrayRef<int64_t> dilations = op.getDilation();
908 if (llvm::any_of(dilations, [](int64_t d) { return d < 1; }))
909 return op.emitOpError("expect all dilation values to be >= 1, got ")
910 << dilations;
911
912 const RankedTensorType outputType =
913 llvm::dyn_cast<RankedTensorType>(op.getOutput().getType());
914 if (!outputType)
915 // Skip following checks if output is not ranked
916 return success();
917
918 const RankedTensorType inputType =
919 llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
920 const RankedTensorType weightType =
921 llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
922
923 if (inputType && weightType) {
924 // input = [_,IH,IW,_], weight = [_,KH,KW,_], output = [_,OH,OW,_]
925 if constexpr (std::is_same<T, tosa::Conv2DOp>::value) {
926 if (failed(verifyConvOutputSize(
927 op, inputType.getDimSize(1), weightType.getDimSize(1),
928 outputType.getDimSize(1), padding[0], padding[1], strides[0],
929 dilations[0], "height", "y", "top", "bottom")))
930 return failure();
931
932 if (failed(verifyConvOutputSize(
933 op, inputType.getDimSize(2), weightType.getDimSize(2),
934 outputType.getDimSize(2), padding[2], padding[3], strides[1],
935 dilations[1], "width", "x", "left", "right")))
936 return failure();
937 }
938
939 // input = [_,IH,IW,_], weight = [KH,KW,_,_], output = [_,OH,OW,_]
940 if constexpr (std::is_same<T, tosa::DepthwiseConv2DOp>::value) {
941 if (failed(verifyConvOutputSize(
942 op, inputType.getDimSize(1), weightType.getDimSize(0),
943 outputType.getDimSize(1), padding[0], padding[1], strides[0],
944 dilations[0], "height", "y", "top", "bottom")))
945 return failure();
946
947 if (failed(verifyConvOutputSize(
948 op, inputType.getDimSize(2), weightType.getDimSize(1),
949 outputType.getDimSize(2), padding[2], padding[3], strides[1],
950 dilations[1], "width", "x", "left", "right")))
951 return failure();
952 }
953
954 // input = [_,ID,IH,IW,_], weight = [_,KD,KH,KW,_], output = [_,OD,OH,OW,_]
955 if constexpr (std::is_same<T, tosa::Conv3DOp>::value) {
956 if (failed(verifyConvOutputSize(
957 op, inputType.getDimSize(1), weightType.getDimSize(1),
958 outputType.getDimSize(1), padding[0], padding[1], strides[0],
959 dilations[0], "depth", "d", "front", "back")))
960 return failure();
961
962 if (failed(verifyConvOutputSize(
963 op, inputType.getDimSize(2), weightType.getDimSize(2),
964 outputType.getDimSize(2), padding[2], padding[3], strides[1],
965 dilations[1], "height", "y", "top", "bottom")))
966 return failure();
967
968 if (failed(verifyConvOutputSize(
969 op, inputType.getDimSize(3), weightType.getDimSize(3),
970 outputType.getDimSize(3), padding[4], padding[5], strides[2],
971 dilations[2], "width", "x", "left", "right")))
972 return failure();
973 }
974 }
975
976 const RankedTensorType biasType =
977 llvm::dyn_cast<RankedTensorType>(op.getBias().getType());
978 if (!biasType)
979 // Skip following checks if bias is not ranked
980 return success();
981
982 const int64_t biasChannels = biasType.getDimSize(0);
983 const int64_t outputChannels =
984 outputType.getDimSize(outputType.getRank() - 1);
985 if (biasChannels == ShapedType::kDynamic ||
986 outputChannels == ShapedType::kDynamic)
987 // Skip following checks if biasChannels or outputChannels is dynamic dim
988 return success();
989
990 if (biasChannels != outputChannels && biasChannels != 1)
991 return op.emitOpError(
992 "bias channels expected to be equal to output channels (")
993 << outputChannels << ") or 1, got " << biasChannels;
994
995 return success();
996}
997
998// Verify whether same type and shape of the given two types.
999static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, Type type1,
1000 StringRef name1, Type type2,
1001 StringRef name2) {
1002 auto shapeType1 = dyn_cast<ShapedType>(type1);
1003 auto shapeType2 = dyn_cast<ShapedType>(type2);
1004 if (!shapeType1 || !shapeType2)
1005 return failure();
1006
1007 auto elemType1 = shapeType1.getElementType();
1008 auto elemType2 = shapeType2.getElementType();
1009 if (elemType1 != elemType2)
1010 return op->emitOpError()
1011 << "require same element type for " << name1 << " (" << elemType1
1012 << ") and " << name2 << " (" << elemType2 << ")";
1013
1014 if (failed(verifyCompatibleShape(type1, type2)))
1015 return op->emitOpError()
1016 << "require same shapes for " << name1 << " (" << type1 << ") and "
1017 << name2 << " (" << type2 << ")";
1018
1019 return success();
1020}
1021
1022// Verify whether same length, type, and shape of the given two tensor lists.
1023static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, ValueRange list1,
1024 StringRef name1,
1025 ValueRange list2,
1026 StringRef name2) {
1027 if (list1.size() != list2.size())
1028 return op->emitOpError()
1029 << "require same number of values in " << name1 << " ("
1030 << list1.size() << ") and " << name2 << " (" << list2.size() << ")";
1031
1032 for (auto [type1, type2] :
1033 llvm::zip_equal(list1.getTypes(), list2.getTypes())) {
1034 if (errorIfTypeOrShapeMismatch(op, type1, name1, type2, name2).failed())
1035 return failure();
1036 }
1037
1038 return success();
1039}
1040
1041static inline LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type) {
1042 ShapeAdaptor shapeAdaptor(type);
1043 if (!shapeAdaptor.hasRank() || !shapeAdaptor.hasStaticShape())
1044 return success();
1045
1046 return shapeAdaptor.getNumElements() == 1 ? success() : failure();
1047}
1048
1049template <typename T>
1050static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) {
1051 Operation *symTableOp =
1052 op->template getParentWithTrait<OpTrait::SymbolTable>();
1053 if (!symTableOp)
1054 // If the operation is not the scope of a symbol table, we cannot
1055 // verify it against it's declaration.
1056 return success();
1057
1058 SymbolTable symTable(symTableOp);
1059 const auto varOp = symTable.lookup<tosa::VariableOp>(op.getName());
1060
1061 // Verify prior declaration
1062 if (!varOp)
1063 return op->emitOpError("'")
1064 << op.getName() << "' has not been declared by 'tosa.variable'";
1065
1066 // Verify type and shape
1067 auto variableType = getVariableType(varOp);
1068 if (errorIfTypeOrShapeMismatch(op, type, name, variableType,
1069 "the input tensor")
1070 .failed())
1071 return failure();
1072 return success();
1073}
1074
1075// verify that inType and outType have same element types
1076template <typename T>
1077static LogicalResult verifySameElementTypes(T op, Type aType, Type bType,
1078 StringRef aName = "input",
1079 StringRef bName = "output") {
1080 auto aTType = llvm::dyn_cast<TensorType>(aType);
1081 auto bTType = llvm::dyn_cast<TensorType>(bType);
1082 if (!aTType) {
1083 op.emitOpError("expect shaped tensor for") << aName << ", got " << aType;
1084 return failure();
1085 }
1086 if (!bTType) {
1087 op.emitOpError("expect shaped tensor for") << bName << ", got" << bType;
1088 return failure();
1089 }
1090 auto aElementType = aTType.getElementType();
1091 auto bElementType = bTType.getElementType();
1092 auto aQuantType =
1093 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(aElementType);
1094 auto bQuantType =
1095 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(bElementType);
1096 if ((aElementType.isIntOrIndexOrFloat() || aQuantType) &&
1097 (bElementType.isIntOrIndexOrFloat() || bQuantType) &&
1098 aElementType != bElementType) {
1099 // only check if both element types are int/index/float/UniformQuantized
1100 // eg, not sure how to check quant::QuantizedType
1101 // this happens in test_conv2d_q_grouped_convolution in
1102 // tfl-to-tosa-pipeline.mlir
1103 op.emitOpError("expect ")
1104 << aName << " and " << bName << " to have same element type, got "
1105 << aElementType << " and " << bElementType;
1106 return failure();
1107 }
1108 return success();
1109}
1110
1111LogicalResult tosa::ArgMaxOp::verify() {
1112 const ShapedType resultType = llvm::cast<ShapedType>(getType());
1113
1114 // Ensure output is of 32-bit integer
1115 if (const auto resultETy = resultType.getElementType();
1116 !resultETy.isIntOrIndex())
1117 return emitOpError("result tensor is not of integer type");
1118
1119 const auto inputType = llvm::cast<ShapedType>(getInput().getType());
1120 if (!inputType.hasRank())
1121 return success();
1122
1123 // Ensure axis is within the tensor rank
1124 const int64_t axis = getAxisAttr().getInt();
1125 if (((axis < 0) || axis >= inputType.getRank()))
1126 return emitOpError("specified axis is outside the rank of the tensor");
1127
1128 if (!resultType.hasRank())
1129 return success();
1130
1131 const ArrayRef<int64_t> inputShape = inputType.getShape();
1132 const ArrayRef<int64_t> outputShape = resultType.getShape();
1133 llvm::SmallVector<int64_t> expectedOutputShape(inputShape);
1134 expectedOutputShape.erase(expectedOutputShape.begin() + axis);
1135 if (failed(verifyCompatibleShape(expectedOutputShape, outputShape)))
1136 return emitOpError("expected output shape '")
1137 << expectedOutputShape << "', got '" << outputShape << "'";
1138
1139 return success();
1140}
1141
1142static LogicalResult verifyPoolingOpImpl(Operation *op,
1143 ArrayRef<int64_t> kernel,
1144 ArrayRef<int64_t> strides,
1145 ArrayRef<int64_t> padding, Value input,
1146 Value output) {
1147 const bool hasKernel = kernel.size() > 0;
1148 const bool hasStrides = strides.size() > 0;
1149 const bool hasPad = padding.size() > 0;
1150
1151 if (hasKernel && llvm::any_of(kernel, [](int64_t s) { return s < 1; }))
1152 return op->emitOpError("expect all kernel values to be >= 1, got ")
1153 << kernel;
1154
1155 if (hasStrides && llvm::any_of(strides, [](int64_t s) { return s < 1; }))
1156 return op->emitOpError("expect all stride values to be >= 1, got ")
1157 << strides;
1158
1159 if (hasPad && llvm::any_of(padding, [](int64_t p) { return p < 0; }))
1160 return op->emitOpError("expect all padding values to be >= 0, got ")
1161 << padding;
1162
1163 if (hasKernel && hasPad) {
1164 // Padding must be less than kernel size to avoid a divide-by-zero
1165 const int64_t kernelX = kernel[1];
1166 const int64_t padLeft = padding[2];
1167 const int64_t padRight = padding[3];
1168 if (padRight >= kernelX || padLeft >= kernelX)
1169 return op->emitOpError("expected left/right padding to be less than the "
1170 "width of the kernel, got pad_left=")
1171 << padLeft << ", pad_right=" << padRight
1172 << ", kernel_x=" << kernelX;
1173
1174 const int64_t kernelY = kernel[0];
1175 const int64_t padTop = padding[0];
1176 const int64_t padBottom = padding[1];
1177 if (padTop >= kernelY || padBottom >= kernelY)
1178 return op->emitOpError("expected top/bottom padding to be less than the "
1179 "height of the kernel, got pad_top=")
1180 << padTop << ", pad_bottom=" << padBottom
1181 << ", kernel_y=" << kernelY;
1182 }
1183
1184 const auto inputType = llvm::dyn_cast<RankedTensorType>(input.getType());
1185 const auto outputType = llvm::dyn_cast<RankedTensorType>(output.getType());
1186 if (!inputType || !outputType)
1187 return success();
1188
1189 if (hasKernel && hasStrides && hasPad) {
1190 const auto verifyOutputSize =
1191 [op](const int64_t inputSize, const int64_t outputSize,
1192 const int64_t kernelSize, const int64_t strideSize,
1193 const int64_t padBefore, const int64_t padAfter,
1194 const llvm::StringRef dimName, const llvm::StringRef dimAxis,
1195 const llvm::StringRef padBeforeName,
1196 const llvm::StringRef padAfterName) -> LogicalResult {
1197 if (ShapedType::isDynamic(inputSize))
1198 return success();
1199
1200 const std::optional<int64_t> calculatedOutSizeMinusOne =
1201 idivCheck(inputSize + padBefore + padAfter - kernelSize, strideSize);
1202 if (!calculatedOutSizeMinusOne.has_value())
1203 return op->emitOpError("expected input_")
1204 << dimName << " + pad_" << padBeforeName << " + pad_"
1205 << padAfterName << " - kernel_" << dimAxis
1206 << " to be wholly divisible by stride_" << dimAxis << ", got ("
1207 << inputSize << " + " << padBefore << " + " << padAfter << " - "
1208 << kernelSize << ") / " << strideSize;
1209
1210 const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
1211 if (ShapedType::isStatic(outputSize) && calculatedOutSize != outputSize)
1212 return op->emitOpError("calculated output ")
1213 << dimName << " did not match expected: " << "calculated="
1214 << calculatedOutSize << ", expected=" << outputSize;
1215
1216 return success();
1217 };
1218
1219 if (failed(verifyOutputSize(inputType.getDimSize(1),
1220 outputType.getDimSize(1), kernel[0], strides[0],
1221 padding[0], padding[1], "height", "y", "top",
1222 "bottom")))
1223 return failure();
1224
1225 if (failed(verifyOutputSize(
1226 inputType.getDimSize(2), outputType.getDimSize(2), kernel[1],
1227 strides[1], padding[2], padding[3], "width", "x", "left", "right")))
1228 return failure();
1229 }
1230 return success();
1231}
1232
1233template <typename T>
1234static LogicalResult verifyPoolingOp(T op) {
1235 return verifyPoolingOpImpl(op.getOperation(), op.getKernel(), op.getStride(),
1236 op.getPad(), op.getInput(), op.getOutput());
1237}
1238
1239template <typename T>
1240static LogicalResult verifyAvgPoolCommonTypeAndZpChecks(T op) {
1241 const Type inputETy = getStorageElementTypeOrSelf(op.getInput().getType());
1242 const Type resultETy = getStorageElementTypeOrSelf(op.getOutput().getType());
1243 const Type inputZpETy =
1244 getStorageElementTypeOrSelf(op.getInputZp().getType());
1245 const Type outputZpETy =
1246 getStorageElementTypeOrSelf(op.getOutputZp().getType());
1247
1248 auto accType = op.getAccType();
1249 if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
1250 return op.emitOpError("accumulator type for integer tensor is not i32");
1251
1252 if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))
1253 return op.emitOpError("accumulator type for f16 tensor is not f16/f32");
1254
1255 if (inputETy.isBF16() && !accType.isF32())
1256 return op.emitOpError("accumulator type for bf16 tensor is not f32");
1257
1258 if (inputETy.isF32() && !accType.isF32())
1259 return op.emitOpError("accumulator type for f32 tensor is not f32");
1260
1261 if (inputETy != inputZpETy)
1262 return op.emitOpError("expect both input and its zero point are the same "
1263 "element type, got ")
1264 << inputETy << " and " << inputZpETy;
1265
1266 if (resultETy != outputZpETy)
1267 return op.emitOpError("expect both output and its zero point are the same "
1268 "element type, got ")
1269 << resultETy << " and " << outputZpETy;
1270
1271 FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
1272 if (succeeded(maybeIZp) && op.verifyInputZeroPoint(*maybeIZp).failed())
1273 return failure();
1274
1275 FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
1276 if (succeeded(maybeOZp) && op.verifyOutputZeroPoint(*maybeOZp).failed())
1277 return failure();
1278
1279 return success();
1280}
1281
1282namespace {
1283struct AdaptivePoolingConstShapeValues {
1284 llvm::SmallVector<int64_t> kernel;
1285 llvm::SmallVector<int64_t> stride;
1286 llvm::SmallVector<int64_t> pad;
1287};
1288} // namespace
1289
// Compile-time predicate: true when T is tosa::AvgPool2dAdaptiveOp or
// tosa::MaxPool2dAdaptiveOp. It gates the adaptive-pooling shape extraction
// helper that follows via std::enable_if.
// NOTE(review): the declaration line (original line 1291) is missing from this
// listing; presumably it reads
// `constexpr bool IsSupportedAdaptivePoolConstShapeVerifyOp =` -- confirm.
1290template <typename T>
1292 std::is_same_v<T, tosa::AvgPool2dAdaptiveOp> ||
1293 std::is_same_v<T, tosa::MaxPool2dAdaptiveOp>;
1294
// Reads the kernel/stride/pad shape operands of an adaptive pooling op into
// `values` by passing each operand's defining op to tosa::getConstShapeValues.
// Only instantiable for ops accepted by
// IsSupportedAdaptivePoolConstShapeVerifyOp.
// NOTE(review): the line naming this function (original line 1298) is missing
// from this listing; the comment in AvgPool2dAdaptiveOp::verify suggests it is
// extractAdaptivePoolingConstShapeOperands -- confirm. What
// getConstShapeValues does for a non-constant operand is not visible here;
// that comment says the corresponding list is left empty.
1295template <typename T,
1296 typename std::enable_if<IsSupportedAdaptivePoolConstShapeVerifyOp<T>,
1297 int>::type = 0>
1299 T op, AdaptivePoolingConstShapeValues &values) {
1300 tosa::getConstShapeValues(op.getKernel().getDefiningOp(), values.kernel);
1301 tosa::getConstShapeValues(op.getStride().getDefiningOp(), values.stride);
1302 tosa::getConstShapeValues(op.getPad().getDefiningOp(), values.pad);
1303}
1304
// Verifies tosa.avg_pool2d. First the shared kernel/stride/pad and output-size
// checks run via verifyPoolingOp, then a second check that can also return
// failure.
// NOTE(review): the condition of that second check (original line 1308) is
// missing from this listing; presumably
// `if (failed(verifyAvgPoolCommonTypeAndZpChecks(*this)))` -- confirm.
1305LogicalResult tosa::AvgPool2dOp::verify() {
1306 if (failed(verifyPoolingOp(*this)))
1307 return failure();
1309 return failure();
1310 return success();
1311}
1312
// Verifies tosa.avg_pool2d_adaptive. The kernel/stride/pad lists in `values`
// go through verifyPoolingOpImpl, followed by a second check that can also
// return failure.
// NOTE(review): original lines 1315 and 1326 are missing from this listing;
// presumably `extractAdaptivePoolingConstShapeOperands(*this, values);` (which
// fills `values`) and `if (failed(verifyAvgPoolCommonTypeAndZpChecks(*this)))`
// -- confirm.
1313LogicalResult tosa::AvgPool2dAdaptiveOp::verify() {
1314 AdaptivePoolingConstShapeValues values;
1316
1317 // Pad/stride/kernel need not be constant; non-constant values simply cannot
1318 // be checked. extractAdaptivePoolingConstShapeOperands returns an empty list
1319 // for each operand that is not a compile-time constant, and
1320 // verifyPoolingOpImpl skips any check whose values are absent.
1321
1322 if (failed(verifyPoolingOpImpl(getOperation(), values.kernel, values.stride,
1323 values.pad, getInput(), getOutput())))
1324 return failure();
1325
1327 return failure();
1328
1329 return success();
1330}
1331
1332LogicalResult tosa::ClampOp::verify() {
1333 mlir::Type inputETy =
1334 llvm::cast<ShapedType>(getInput().getType()).getElementType();
1335 if (auto quantType =
1336 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
1337 inputETy = getStorageElementTypeFromQuantized(quantType);
1338 }
1339 mlir::Type outputETy =
1340 llvm::cast<ShapedType>(getOutput().getType()).getElementType();
1341 if (auto quantType =
1342 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
1343 outputETy = getStorageElementTypeFromQuantized(quantType);
1344 }
1345 if (inputETy != outputETy)
1346 return emitOpError("input/output element types are incompatible.");
1347
1348 auto maxValAttr = getMaxValAttr();
1349 auto minValAttr = getMinValAttr();
1350
1351 unsigned dataTypeBitWidth = inputETy.getIntOrFloatBitWidth();
1352
1353 if (inputETy.isInteger(dataTypeBitWidth)) {
1354 // if input datatype is integer, check that the min_val/max_val attributes
1355 // are integer attributes, and that their type is the same as the input's
1356 // datatype
1357 auto intMaxValAttr = mlir::dyn_cast<mlir::IntegerAttr>(maxValAttr);
1358 auto intMinValAttr = mlir::dyn_cast<mlir::IntegerAttr>(minValAttr);
1359 if (!intMaxValAttr || !intMinValAttr ||
1360 (intMaxValAttr.getType() != intMinValAttr.getType()) ||
1361 (intMaxValAttr.getType() != inputETy))
1362 return emitOpError("min/max attributes types are incompatible with "
1363 "input/output element types.");
1364
1365 const bool isUnsigned = inputETy.isUnsignedInteger();
1366 const bool isBoolean = inputETy.isInteger(1);
1367 const APInt minVal = intMinValAttr.getValue();
1368 const APInt maxVal = intMaxValAttr.getValue();
1369 if ((isUnsigned || isBoolean) ? maxVal.ult(minVal) : maxVal.slt(minVal))
1370 return emitOpError("expected min_val <= max_val, got min_val=")
1371 << minValAttr << ", max_val=" << maxValAttr;
1372 } else {
1373 // otherwise, input datatype is float, check that the min_val/max_val
1374 // attributes share the same type and that their type is the same as the
1375 // input's datatype
1376 auto floatMaxValAttr = mlir::dyn_cast<mlir::FloatAttr>(maxValAttr);
1377 auto floatMinValAttr = mlir::dyn_cast<mlir::FloatAttr>(minValAttr);
1378 if (!floatMaxValAttr || !floatMinValAttr ||
1379 (floatMaxValAttr.getType() != floatMinValAttr.getType()) ||
1380 (floatMaxValAttr.getType() != inputETy))
1381 return emitOpError("min/max attributes types are incompatible with "
1382 "input/output element types.");
1383
1384 const APFloat minVal = floatMinValAttr.getValue();
1385 const APFloat maxVal = floatMaxValAttr.getValue();
1386 if (minVal.isNaN() || maxVal.isNaN())
1387 return emitOpError("min/max attributes should not be 'NaN', got min_val=")
1388 << minValAttr << ", max_val=" << maxValAttr;
1389
1390 if (maxVal < minVal)
1391 return emitOpError("expected min_val <= max_val, got min_val=")
1392 << minValAttr << ", max_val=" << maxValAttr;
1393 }
1394
1395 return success();
1396}
1397
1398//===----------------------------------------------------------------------===//
1399// TOSA Operator Quantization Builders.
1400//===----------------------------------------------------------------------===//
1401
1402/// This builder is called on all convolution operators except TransposeConv,
1403/// which has specialized output shape semantics. The builder also defines the
1404/// bitwidth of the output given the bit width of the input & weight content.
///
/// Operands are input, weight, bias and the two zero points returned by
/// createZPsAsConst; pad, stride, dilation and acc_type become attributes.
/// The caller's outputType is used as-is unless buildConvOpQuantizationAttr
/// returns an attribute (presumably for quantized input/weight), in which case
/// buildConvOpResultTypeInfo supplies the result type.
/// NOTE(review): the signature line (original line 1405) is missing from this
/// listing; presumably `static void buildConvOpWithQuantInfo(OpBuilder
/// &builder, OperationState &result,` -- confirm.
1406 Type outputType, Value input, Value weight,
1407 Value bias, DenseI64ArrayAttr pad,
1408 DenseI64ArrayAttr stride,
1409 DenseI64ArrayAttr dilation,
1410 TypeAttr accType) {
1411 auto zps = createZPsAsConst(builder, input, weight);
1412 result.addOperands({input, weight, bias, zps.first, zps.second});
1413 result.addAttribute("pad", pad);
1414 result.addAttribute("stride", stride);
1415 result.addAttribute("dilation", dilation);
1416 result.addAttribute("acc_type", accType);
1417 Type finalOutputType = outputType;
1418 auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
1419 if (quantAttr) {
1420 finalOutputType =
1421 buildConvOpResultTypeInfo(builder, outputType, input, weight);
1422 }
1423 result.addTypes(finalOutputType);
1424}
1425
1426/// Handles tosa.transpose_conv2d, which takes out_pad and stride attributes
1427/// instead of pad/stride/dilation.
///
/// Operands are input, weight, bias and the two zero points returned by
/// createZPsAsConst; out_pad, stride and acc_type become attributes. The
/// caller's outputType is replaced by buildConvOpResultTypeInfo's result when
/// buildConvOpQuantizationAttr returns an attribute.
/// NOTE(review): the signature continuation (original line 1429) is missing
/// from this listing; presumably `buildTransConvOpWithQuantInfo(OpBuilder
/// &builder, OperationState &result,` -- confirm.
1428static void
1430 Type outputType, Value input, Value weight,
1431 Value bias, DenseI64ArrayAttr outpad,
1432 DenseI64ArrayAttr stride, TypeAttr accType) {
1433 auto zps = createZPsAsConst(builder, input, weight);
1434 result.addOperands({input, weight, bias, zps.first, zps.second});
1435 result.addAttribute("out_pad", outpad);
1436 result.addAttribute("stride", stride);
1437 result.addAttribute("acc_type", accType);
1438 Type finalOutputType = outputType;
1439 auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
1440 if (quantAttr) {
1441 finalOutputType =
1442 buildConvOpResultTypeInfo(builder, outputType, input, weight);
1443 }
1444 result.addTypes(finalOutputType);
1445}
1446
1447/// The tosa.matmul op is also intended to be generated where a fully_connected
1448/// op must be constructed where the weight is not a constant. In this case,
1449/// the fully_connected op must be expressed using matmul.
1450/// TODO: Add link to the legalization document explaining this.
///
/// Operands are a, b and the two zero points returned by createZPsAsConst.
/// When buildMatMulOpQuantizationAttr returns an attribute, the result element
/// type is replaced by the accumulator type: i48 when `a` has 16-bit storage
/// elements, i32 otherwise. outputType must then be a ShapedType (asserted).
/// NOTE(review): the signature line (original line 1451) is missing from this
/// listing; presumably `static void buildMatMulOpWithQuantInfo(OpBuilder
/// &builder,` -- confirm.
1452 OperationState &result, Type outputType,
1453 Value a, Value b) {
1454 auto zps = createZPsAsConst(builder, a, b);
1455 result.addOperands({a, b, zps.first, zps.second});
1456
1457 Type finalOutputType{outputType};
1458 if (auto quantAttr = buildMatMulOpQuantizationAttr(builder, a, b)) {
1459 auto eType = getStorageElementTypeOrSelf(a.getType());
1460 auto inputBits = eType.getIntOrFloatBitWidth();
1461
1462 auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
1463 assert(outputShapedType && "Output must be a shaped type");
1464
 // Accumulator width depends only on the input storage bit width.
1465 IntegerType accElementType;
1466 if (inputBits == 16)
1467 accElementType = builder.getIntegerType(48);
1468 else
1469 accElementType = builder.getI32Type();
1470
1471 finalOutputType = outputShapedType.clone(accElementType);
1472 }
1473 result.addTypes(finalOutputType);
1474}
1475
1476/// Both the tosa.avg_pool2d and unary ops use the same
1477/// UnaryOpQuantizationAttr but avg_pool operator has its own builder as it
1478/// has additional parameters not part of the unary ops.
///
/// Input/output zero points default to 0 and are taken from
/// buildUnaryOpQuantizationAttr when it returns an attribute; they are
/// materialized with createZeroPointTensor (input zp from input.getType(),
/// output zp from outputType). kernel, stride, pad and acc_type become
/// attributes.
/// NOTE(review): the signature continuation (original line 1480) is missing
/// from this listing; presumably `buildAvgPool2dOpWithQuantInfo(OpBuilder
/// &builder, OperationState &result,` -- confirm.
1479static void
1481 Type outputType, Value input,
1482 DenseArrayAttr kernel, DenseArrayAttr stride,
1483 DenseArrayAttr pad, TypeAttr accType) {
1484 const Location loc{result.location};
1485 int64_t inputZp{0};
1486 int64_t outputZp{0};
1487
1488 if (auto quantAttr =
1489 buildUnaryOpQuantizationAttr(builder, input, outputType)) {
1490 inputZp = quantAttr.getInputZp();
1491 outputZp = quantAttr.getOutputZp();
1492 }
1493 const std::optional<Value> inputZpOp =
1494 createZeroPointTensor(builder, loc, input.getType(), inputZp);
1495 if (!inputZpOp) {
1496 (void)emitError(
1497 loc,
1498 "Failed to create input zero point tensor for quantized AVG_POOL2D op");
1499 }
1500 const std::optional<Value> outputZpOp =
1501 createZeroPointTensor(builder, loc, outputType, outputZp);
1502 if (!outputZpOp) {
1503 (void)emitError(loc, "Failed to create output zero point tensor for "
1504 "quantized AVG_POOL2D op");
1505 }
1506
1507 if (inputZpOp && outputZpOp) {
1508 result.addOperands({input, inputZpOp.value(), outputZpOp.value()});
1509 } else {
1510 // Failed to create one or more zero points above: add only the input as an
1511 // operand. This will trigger an error when building the op because of the
1512 // missing zero points.
1513 result.addOperands({input});
1514 }
1515 result.addAttribute("kernel", kernel);
1516 result.addAttribute("stride", stride);
1517 result.addAttribute("pad", pad);
1518 result.addAttribute("acc_type", accType);
1519 result.types.push_back(outputType);
1520}
1521
1522/// This builder mirrors avg_pool2d quant-info handling and materializes
1523/// kernel/stride/pad as const_shape operands for avg_pool2d_adaptive.
///
/// Operand order on success: input, input zp, output zp, kernel, stride, pad.
/// The shape operands are only materialized when both zero points were
/// created; acc_type is always added as an attribute.
/// NOTE(review): original lines 1524 and 1526 are missing from this listing;
/// presumably the function name (`buildAvgPool2dAdaptiveOpWithQuantInfo(`) and
/// the `DenseI64ArrayAttr kernel, DenseI64ArrayAttr stride,
/// DenseI64ArrayAttr pad,` parameter line -- confirm.
1525 OpBuilder &builder, OperationState &result, Type outputType, Value input,
1527 TypeAttr accType) {
1528 const Location loc{result.location};
1529 int64_t inputZp{0};
1530 int64_t outputZp{0};
1531
1532 if (auto quantAttr =
1533 buildUnaryOpQuantizationAttr(builder, input, outputType)) {
1534 inputZp = quantAttr.getInputZp();
1535 outputZp = quantAttr.getOutputZp();
1536 }
1537 const std::optional<Value> inputZpOp =
1538 createZeroPointTensor(builder, loc, input.getType(), inputZp);
1539 if (!inputZpOp) {
1540 (void)emitError(loc,
1541 "Failed to create input zero point tensor for quantized "
1542 "AVG_POOL2D_ADAPTIVE op");
1543 }
1544 const std::optional<Value> outputZpOp =
1545 createZeroPointTensor(builder, loc, outputType, outputZp);
1546 if (!outputZpOp) {
1547 (void)emitError(loc, "Failed to create output zero point tensor for "
1548 "quantized AVG_POOL2D_ADAPTIVE op");
1549 }
1550
1551 if (inputZpOp && outputZpOp) {
1552 ImplicitLocOpBuilder b(loc, builder);
1553 Value kernelShape = getTosaConstShape(b, kernel.asArrayRef());
1554 Value strideShape = getTosaConstShape(b, stride.asArrayRef());
1555 Value padShape = getTosaConstShape(b, pad.asArrayRef());
1556 result.addOperands({input, inputZpOp.value(), outputZpOp.value(),
1557 kernelShape, strideShape, padShape});
1558 } else {
1559 // Failed to create one or more zero points above: just add input as
1560 // operands. This will trigger error in building the op because of missing
1561 // operands.
1562 result.addOperands({input});
1563 }
1564 result.addAttribute("acc_type", accType);
1565 result.types.push_back(outputType);
1566}
1567
1568/// This builder is called on single-parameter negate operator
1569/// to construct input and output zero points based on their
1570/// types.
///
/// Zero points default to 0 and are taken from buildUnaryOpQuantizationAttr
/// when it returns an attribute. On success the operands are input, input1 zp
/// and output zp.
/// NOTE(review): the signature line (original line 1571) is missing from this
/// listing; presumably `static void buildNegateOpWithQuantInfo(OpBuilder
/// &builder,` -- confirm.
1572 OperationState &result, Type outputType,
1573 Value input) {
1574 const Location loc{result.location};
1575 int64_t input1Zp{0};
1576 int64_t outputZp{0};
1577 auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
1578 if (quantAttr) {
1579 input1Zp = quantAttr.getInputZp();
1580 outputZp = quantAttr.getOutputZp();
1581 }
1582 const std::optional<Value> input1ZpOp =
1583 createZeroPointTensor(builder, loc, input.getType(), input1Zp);
1584 if (!input1ZpOp) {
1585 (void)emitError(
1586 loc, "Failed to create input1 zero point for quantized NEGATE op");
1587 }
1588
 // NOTE(review): the output zero point is built from input.getType() rather
 // than outputType (the avg_pool2d builders use outputType for this); this
 // only matters if input and output storage types can differ -- confirm
 // intent.
1589 const std::optional<Value> outputZpOp =
1590 createZeroPointTensor(builder, loc, input.getType(), outputZp);
1591 if (!outputZpOp) {
1592 (void)emitError(
1593 loc, "Failed to create output zero point for quantized NEGATE op");
1594 }
1595
1596 if (input1ZpOp && outputZpOp) {
1597 result.addOperands({input, input1ZpOp.value(), outputZpOp.value()});
1598 } else {
1599 // failed to create one or more zero points above: just add input as
1600 // operands. This will trigger error in building the op because of
1601 // missing zero points
1602 result.addOperands({input});
1603 }
1604
1605 result.types.push_back(outputType);
1606}
1607
1608/// This builder is called on the TOSA pad operator. It reads the input zero
1609/// point from buildPadOpQuantizationAttr (0 when no attribute is returned)
1610/// and passes it to createPadConstTensor to build the pad-constant operand.
///
/// Operands are input, paddings and the pad-constant value.
/// NOTE(review): the signature line (original line 1611) is missing from this
/// listing; presumably `static void buildPadOpWithQuantInfo(OpBuilder
/// &builder, OperationState &result,` -- confirm.
1612 Type outputType, Value input,
1613 Value paddings) {
1614 const Location loc{result.location};
1615 int32_t zp{0};
1616 const auto quantAttr = buildPadOpQuantizationAttr(builder, input);
1617 if (quantAttr) {
1618 zp = static_cast<int32_t>(quantAttr.getInputZp());
1619 }
1620 const auto padConstOp{createPadConstTensor(builder, loc, input, zp)};
1621 result.addOperands({input, paddings, padConstOp});
1622 result.types.push_back(outputType);
1623}
1624
1626 StringRef name, Type variableType,
1627 Attribute initialValue) {
 // Populates `result` with sym_name, var_shape (the shape converted by
 // convertFromMlirShape, stored as an index tensor attribute), type (the
 // element type) and initial_value attributes -- presumably for
 // tosa.variable. A non-shaped or unranked variableType emits an error and
 // returns without adding any attributes.
 // NOTE(review): the signature line (original line 1625) is missing from this
 // listing; presumably `static void buildVariableOp(OpBuilder &builder,
 // OperationState &result,` -- confirm.
1628 const Location loc{result.location};
1629 auto nameAttr = builder.getStringAttr(name);
1630
1631 auto shapedType = dyn_cast<ShapedType>(variableType);
1632 if (!shapedType) {
1633 (void)emitError(loc, "variable type must be a shaped type");
1634 return;
1635 }
1636 if (!shapedType.hasRank()) {
1637 (void)emitError(loc, "variable type must be a ranked type");
1638 return;
1639 }
1640
1641 auto elementType = shapedType.getElementType();
1642 auto elementTypeAttr = TypeAttr::get(elementType);
1643 ArrayRef<int64_t> shape = shapedType.getShape();
1644 auto varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
1645
1646 result.addAttribute("sym_name", nameAttr);
1647 result.addAttribute("var_shape", varShapeAttr);
1648 result.addAttribute("type", elementTypeAttr);
1649 result.addAttribute("initial_value", initialValue);
1650}
1651
1652//===----------------------------------------------------------------------===//
1653// TOSA Operator Return Type Inference.
1654//===----------------------------------------------------------------------===//
1655static FailureOr<int64_t> resolveBroadcastDim(const int64_t dim1,
1656 const int64_t dim2) {
1657 if (dim1 == 1)
1658 return dim2;
1659 if (dim2 == 1)
1660 return dim1;
1661
1662 if (ShapedType::isStatic(dim1) && ShapedType::isStatic(dim2) && dim1 != dim2)
1663 return failure();
1664
1665 // Prefer static dimension over dynamic
1666 return ShapedType::isDynamic(dim1) ? dim2 : dim1;
1667}
1668
1669static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
1670 SmallVector<int64_t> &outShape) {
1671 int64_t outRank = 0;
1672 for (int i = 0, e = operands.size(); i != e; ++i) {
1673 auto shape = operands.getShape(i);
1674 if (!shape.hasRank()) {
1675 // TODO(jennik): Update function to have better case handling for
1676 // invalid operands and for ranked tensors.
1677 return failure();
1678 }
1679 outRank = std::max<int64_t>(outRank, shape.getRank());
1680 }
1681
1682 outShape.resize(outRank, 1);
1683
1684 for (int i = 0, e = operands.size(); i != e; ++i) {
1685 auto shape = operands.getShape(i);
1686 auto rankDiff = outShape.size() - shape.getRank();
1687
1688 for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
1689 auto dim1 = outShape[i + rankDiff];
1690 auto dim2 = shape.getDimSize(i);
1691
1692 const FailureOr<int64_t> maybeResolvedDim =
1693 resolveBroadcastDim(dim1, dim2);
1694 if (failed(maybeResolvedDim))
1695 return failure();
1696 const int64_t resolvedDim = *maybeResolvedDim;
1697 outShape[i + rankDiff] = resolvedDim;
1698 }
1699 }
1700
1701 return success();
1702}
1703
1704LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
1705 MLIRContext *context, ::std::optional<Location> location,
1706 ArgMaxOp::Adaptor adaptor,
1707 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1708 ShapeAdaptor inputShape(adaptor.getInput().getType());
1709 IntegerAttr axis = adaptor.getProperties().axis;
1710 int32_t axisVal = axis.getValue().getSExtValue();
1711
1712 if (!inputShape.hasRank()) {
1713 inferredReturnShapes.push_back(ShapedTypeComponents());
1714 return success();
1715 }
1716
1717 SmallVector<int64_t> outShape;
1718 outShape.reserve(inputShape.getRank() - 1);
1719 for (int i = 0, s = inputShape.getRank(); i < s; i++) {
1720 if (i == axisVal)
1721 continue;
1722 outShape.push_back(inputShape.getDimSize(i));
1723 }
1724
1725 inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1726 return success();
1727}
1728
1729LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
1730 MLIRContext *context, ::std::optional<Location> location,
1731 RFFT2dOp::Adaptor adaptor,
1732 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1733 ShapeAdaptor inputShape(adaptor.getInputReal().getType());
1734
1735 if (!inputShape.hasRank())
1736 return failure();
1737
1738 llvm::SmallVector<int64_t> outputShape;
1739 outputShape.resize(3, ShapedType::kDynamic);
1740 outputShape[0] = inputShape.getDimSize(0);
1741 outputShape[1] = inputShape.getDimSize(1);
1742 int64_t inWidth = inputShape.getDimSize(2);
1743
1744 // Note that we can support this calculation symbolically
1745 // in the future e.g. [x, y, z] -> [x, y, z / 2 + 1]
1746 if (inWidth != ShapedType::kDynamic)
1747 outputShape[2] = inWidth / 2 + 1;
1748
1749 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1750 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1751
1752 return success();
1753}
1754
1755static LogicalResult verifyDimIsPowerOfTwo(Operation *op, const int64_t dimSize,
1756 const llvm::StringRef dimName) {
1757 const bool isPowerOfTwo = (dimSize & (dimSize - 1)) == 0 && dimSize > 0;
1758 if (!isPowerOfTwo)
1759 return op->emitOpError("expected ")
1760 << dimName << " to be a power of two, got " << dimSize;
1761
1762 return success();
1763}
1764
1765LogicalResult tosa::RFFT2dOp::verify() {
1766 const auto outputTypes = getResultTypes();
1767 if (failed(verifyCompatibleShapes(outputTypes)))
1768 return emitOpError("expected output shapes to match, got ") << outputTypes;
1769
1770 const auto inputType =
1771 llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
1772 if (!inputType)
1773 return success();
1774
1775 const int64_t height = inputType.getDimSize(1);
1776 if (ShapedType::isStatic(height) &&
1777 failed(verifyDimIsPowerOfTwo(*this, height, "height")))
1778 return failure();
1779
1780 const int64_t width = inputType.getDimSize(2);
1781 if (ShapedType::isStatic(width) &&
1782 failed(verifyDimIsPowerOfTwo(*this, width, "width")))
1783 return failure();
1784
1785 const auto outputType = llvm::dyn_cast<RankedTensorType>(outputTypes[0]);
1786 if (!outputType)
1787 return success();
1788
1789 // Batch and height input/output dimensions should match
1790 if (failed(verifyCompatibleShape(inputType.getShape().drop_back(),
1791 outputType.getShape().drop_back())))
1792 return emitOpError("expected batch and height dimensions of input/output "
1793 "to match, got input=")
1794 << inputType << " output=" << outputType;
1795
1796 // Output width dimension expected to be input_width / 2 + 1
1797 const int64_t outputWidth = outputType.getDimSize(2);
1798 if (ShapedType::isStatic(width) && ShapedType::isStatic(outputWidth) &&
1799 (outputWidth != (width / 2) + 1))
1800 return emitOpError(
1801 "expected output width to be equal to input_width / 2 + 1, got ")
1802 << outputWidth;
1803
1804 return success();
1805}
1806
1807LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
1808 MLIRContext *context, ::std::optional<Location> location,
1809 FFT2dOp::Adaptor adaptor,
1810 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1811 inferredReturnShapes.push_back(
1812 ShapedTypeComponents(ShapeAdaptor(adaptor.getInputReal().getType())));
1813 inferredReturnShapes.push_back(
1814 ShapedTypeComponents(ShapeAdaptor(adaptor.getInputImag().getType())));
1815 return success();
1816}
1817
1818LogicalResult tosa::FFT2dOp::verify() {
1819 const auto inputRealType =
1820 llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
1821 const auto inputImagType =
1822 llvm::dyn_cast<RankedTensorType>(getInputImag().getType());
1823 if (!inputRealType || !inputImagType)
1824 return success();
1825
1826 const auto trySelectStaticDim = [](const int64_t a, const int64_t b) {
1827 return ShapedType::isDynamic(a) ? a : b;
1828 };
1829
1830 const int64_t height = trySelectStaticDim(inputRealType.getDimSize(1),
1831 inputImagType.getDimSize(1));
1832 if (ShapedType::isStatic(height) &&
1833 failed(verifyDimIsPowerOfTwo(*this, height, "height")))
1834 return failure();
1835
1836 const int64_t width = trySelectStaticDim(inputRealType.getDimSize(2),
1837 inputImagType.getDimSize(2));
1838 if (ShapedType::isStatic(width) &&
1839 failed(verifyDimIsPowerOfTwo(*this, width, "width")))
1840 return failure();
1841
1842 return success();
1843}
1844
1845LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
1846 MLIRContext *context, ::std::optional<Location> location,
1847 ConcatOp::Adaptor adaptor,
1848 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1849 // Infer all dimension sizes by reducing based on inputs.
1850 const Properties &prop = adaptor.getProperties();
1851 int32_t axis = prop.axis.getValue().getSExtValue();
1852 llvm::SmallVector<int64_t> outputShape;
1853 bool hasRankedInput = false;
1854 for (auto operand : adaptor.getOperands()) {
1855 ShapeAdaptor operandShape(operand.getType());
1856 if (!operandShape.hasRank())
1857 continue;
1858
1859 // Copy the Operand's rank.
1860 if (!hasRankedInput)
1861 outputShape.resize(operandShape.getRank(), ShapedType::kDynamic);
1862
1863 // Copy shapes until the dim is non-dynamic.
1864 for (int i = 0, s = operandShape.getRank(); i < s; i++) {
1865 if (i == axis || operandShape.isDynamicDim(i))
1866 continue;
1867 if (outputShape[i] == ShapedType::kDynamic)
1868 outputShape[i] = operandShape.getDimSize(i);
1869 if (outputShape[i] != operandShape.getDimSize(i))
1870 return emitOptionalError(location,
1871 "Cannot concat tensors with different sizes"
1872 " on the non-axis dimension ",
1873 i);
1874 }
1875
1876 hasRankedInput = true;
1877 }
1878
1879 if (adaptor.getInput1().empty())
1880 return failure();
1881
1882 Type inputType =
1883 llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType();
1884 if (!hasRankedInput) {
1885 inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
1886 return success();
1887 }
1888
1889 // Determine the dimension size along the concatenation axis.
1890 int64_t concatDimSize = 0;
1891 for (auto operand : adaptor.getOperands()) {
1892 ShapeAdaptor operandShape(operand.getType());
1893
1894 // We need to know the length of the concatenation axis of all inputs to
1895 // determine the dimension size of the output shape.
1896 if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
1897 concatDimSize = ShapedType::kDynamic;
1898 break;
1899 }
1900
1901 concatDimSize += operandShape.getDimSize(axis);
1902 }
1903
1904 outputShape[axis] = concatDimSize;
1905
1906 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
1907 return success();
1908}
1909
1910LogicalResult tosa::ConcatOp::verify() {
1911 // check that each input has same element type as output
1912 auto outType = getOutput().getType();
1913 const Operation::operand_range inputList = getInput1();
1914
1915 // Check there is at least one input
1916 if (inputList.empty())
1917 return emitOpError("expect at least one input");
1918
1919 if (!llvm::all_of(inputList, [&](auto input) {
1920 return succeeded(verifySameElementTypes(
1921 *this, /* inType = */ input.getType(), outType));
1922 })) {
1923 return failure();
1924 }
1925
1926 const int32_t axis = getAxis();
1927 ShapeAdaptor firstRankedInputShape = nullptr;
1928 for (const auto &input : inputList) {
1929 const Type inputType = input.getType();
1930 ShapeAdaptor currShape(inputType);
1931 if (currShape.hasRank()) {
1932 firstRankedInputShape = currShape;
1933 // Check axis is in expected range
1934 if (axis < 0 || axis >= firstRankedInputShape.getRank())
1935 return emitOpError("expect axis to be within range 0 < axis < "
1936 "rank(input1[firstRankedTensorIdx]), got ")
1937 << axis;
1938 break;
1939 }
1940 }
1941
1942 const auto allOperandsHasRank = [](const Value input) {
1943 return ShapeAdaptor(input.getType()).hasRank();
1944 };
1945 if (llvm::all_of(inputList, allOperandsHasRank)) {
1946 const int64_t firstInputRank = firstRankedInputShape.getRank();
1947
1948 for (const auto &[index, input] : llvm::enumerate(inputList.drop_front())) {
1949 const ShapeAdaptor inputShape(input.getType());
1950 const int64_t inputRank = inputShape.getRank();
1951 const size_t operandNum = index + 1;
1952
1953 // Check that each operand has the same rank
1954 if (inputRank != firstInputRank)
1955 return emitOpError(
1956 "expect all operands to have the same rank, but got ")
1957 << firstInputRank << " vs " << inputRank << " on operands 0 and "
1958 << operandNum;
1959
1960 // Check non-axis dims match
1961 for (int i = 0; i < inputRank; i++) {
1962 const int64_t inputDim = inputShape.getDimSize(i);
1963 const int64_t firstInputDim = firstRankedInputShape.getDimSize(i);
1964 if (i == axis || firstRankedInputShape.isDynamicDim(i) ||
1965 inputShape.isDynamicDim(i))
1966 continue;
1967 if (inputDim != firstInputDim)
1968 return emitOpError("expect all operand shapes to have the same sizes "
1969 "on non-axis dimensions, but got ")
1970 << inputDim << " vs " << firstInputDim << " at index " << i
1971 << " on operands 0 and " << operandNum;
1972 }
1973 }
1974
1975 const ShapeAdaptor outputShape(outType);
1976 if (outputShape.hasRank() && outputShape.getRank() != firstInputRank)
1977 return emitOpError("expect output rank to match inputs rank, got ")
1978 << outputShape.getRank() << " vs " << firstInputRank;
1979
1980 // ERROR_IF(axis_sum != shape[axis]);
1981 int64_t axisSum = 0;
1982 for (const auto &input : inputList) {
1983 const ShapeAdaptor inputShape(input.getType());
1984 if (inputShape.isDynamicDim(axis)) {
1985 // make axisSum negative to indicate invalid value
1986 axisSum = -1;
1987 break;
1988 }
1989 axisSum += inputShape.getDimSize(axis);
1990 }
1991
1992 if (axisSum >= 0 && outputShape.hasRank() &&
1993 !outputShape.isDynamicDim(axis) &&
1994 axisSum != outputShape.getDimSize(axis))
1995 return emitOpError("requires sum of axis dimensions of input1 "
1996 "equal to output axis dimension, got ")
1997 << axisSum << " and " << outputShape.getDimSize(axis);
1998 }
1999
2000 return success();
2001}
2002
2003LogicalResult tosa::EqualOp::inferReturnTypeComponents(
2004 MLIRContext *context, ::std::optional<Location> location,
2005 ValueShapeRange operands, DictionaryAttr attributes, PropertyRef properties,
2006 RegionRange regions,
2007 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2008 auto elementType = IntegerType::get(context, /*width=*/1);
2009
2011 if (resolveBroadcastShape(operands, outShape).failed()) {
2012 inferredReturnShapes.push_back(ShapedTypeComponents(elementType));
2013 return success();
2014 }
2015
2016 inferredReturnShapes.push_back(ShapedTypeComponents(outShape, elementType));
2017 return success();
2018}
2019
2020bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
2021 if (l.size() != r.size() || l.size() != 1)
2022 return false;
2023 return succeeded(verifyCompatibleShape(l[0], r[0]));
2024}
2025
2026LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
2027 MLIRContext *context, ::std::optional<Location> location,
2028 MatMulOp::Adaptor adaptor,
2029 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2030 ShapeAdaptor lhsShape(adaptor.getA().getType());
2031 ShapeAdaptor rhsShape(adaptor.getB().getType());
2032
2033 // All shapes are dynamic.
2034 SmallVector<int64_t> outShape;
2035 outShape.resize(3, ShapedType::kDynamic);
2036
2037 if (lhsShape.hasRank()) {
2038 outShape[0] = lhsShape.getDimSize(0);
2039 outShape[1] = lhsShape.getDimSize(1);
2040 }
2041
2042 if (rhsShape.hasRank()) {
2043 outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(0)
2044 : outShape[0];
2045 outShape[2] = rhsShape.getDimSize(2);
2046 }
2047
2048 inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
2049 return success();
2050}
2051
2052LogicalResult MatMulOp::verify() {
2053 const ShapeAdaptor aShape(getA().getType());
2054 const ShapeAdaptor bShape(getB().getType());
2055 const Type aElementType = aShape.getElementType();
2056 const Type bElementType = bShape.getElementType();
2057
2058 const auto aQuantizedEType =
2059 llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
2060 const auto bQuantizedEType =
2061 llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);
2062
2063 if (aQuantizedEType || bQuantizedEType) {
2064 if (!aQuantizedEType || !bQuantizedEType) {
2065 return emitOpError("expect operands to be both quantized or both not "
2066 "quantized, got ")
2067 << aElementType << " and " << bElementType;
2068 }
2069 // both a and b have quantized element types
2070 auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth();
2071 auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth();
2072 if (aQuantWidth != bQuantWidth) {
2073 return emitOpError("expect quantized operands to have same widths, got ")
2074 << aQuantWidth << " and " << bQuantWidth;
2075 }
2076 }
2077
2078 // check a_zp and b_zp
2079 auto aEType = getStorageElementTypeOrSelf(aElementType);
2080 auto aZpEType = getStorageElementTypeOrSelf(getAZp().getType());
2081 if (aEType != aZpEType)
2082 return emitOpError("expect input a and a_zp have the same "
2083 "element type, got ")
2084 << aEType << " and " << aZpEType;
2085
2086 const Type bEType = getStorageElementTypeOrSelf(bElementType);
2087 const Type bZpEType = getStorageElementTypeOrSelf(getBZp().getType());
2088 if (bEType != bZpEType)
2089 return emitOpError("expect input b and b_zp have the same "
2090 "element type, got ")
2091 << bEType << " and " << bZpEType;
2092
2093 FailureOr<int64_t> maybeAZp = getAZeroPoint();
2094 if (succeeded(maybeAZp) && verifyAZeroPoint(*maybeAZp).failed())
2095 return failure();
2096
2097 FailureOr<int64_t> maybeBZp = getBZeroPoint();
2098 if (succeeded(maybeBZp) && verifyBZeroPoint(*maybeBZp).failed())
2099 return failure();
2100
2101 // Verify input/output shapes
2102 int64_t N = ShapedType::kDynamic;
2103 int64_t H = ShapedType::kDynamic;
2104 int64_t W = ShapedType::kDynamic;
2105 int64_t C = ShapedType::kDynamic;
2106
2107 if (aShape.hasRank()) {
2108 N = aShape.getDimSize(0);
2109 H = aShape.getDimSize(1);
2110 C = aShape.getDimSize(2);
2111 }
2112
2113 if (bShape.hasRank()) {
2114 if (failed(tryUpdateDimOrFailure(*this, N, bShape.getDimSize(0), "b",
2115 "batch")) ||
2116 failed(tryUpdateDimOrFailure(*this, C, bShape.getDimSize(1), "b",
2117 "channels")))
2118 return failure();
2119 W = bShape.getDimSize(2);
2120 }
2121
2122 const SmallVector<int64_t, 3> expectedOutputShape = {N, H, W};
2123 const auto outputType = cast<ShapedType>(getResult().getType());
2124 if (outputType.hasRank() &&
2125 failed(
2126 verifyCompatibleShape(outputType.getShape(), expectedOutputShape))) {
2127 InFlightDiagnostic opError = emitOpError("expected output shape ");
2128 printShapeToDiagnostic(opError, outputType.getShape());
2129 opError << " to be compatible with expected output shape ";
2130 printShapeToDiagnostic(opError, expectedOutputShape);
2131 return opError;
2132 }
2133
2134 return success();
2135}
2136
2137LogicalResult tosa::MatmulTBlockScaledOp::inferReturnTypeComponents(
2138 MLIRContext *context, ::std::optional<Location> location,
2139 MatmulTBlockScaledOp::Adaptor adaptor,
2140 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2141 SmallVector<int64_t, 3> outShape(3, ShapedType::kDynamic);
2142
2143 const auto aDataShape = cast<ShapedType>(adaptor.getAData().getType());
2144 if (aDataShape.hasRank()) {
2145 outShape[0] = aDataShape.getDimSize(0);
2146 outShape[1] = aDataShape.getDimSize(1);
2147 }
2148
2149 const auto aScaleShape = cast<ShapedType>(adaptor.getAScale().getType());
2150 if (aScaleShape.hasRank()) {
2151 outShape[0] = ShapedType::isDynamic(outShape[0]) ? aScaleShape.getDimSize(0)
2152 : outShape[0];
2153 outShape[1] = ShapedType::isDynamic(outShape[1]) ? aScaleShape.getDimSize(1)
2154 : outShape[1];
2155 }
2156
2157 // If B batch size is 1, it is broadcast across A's batch size
2158 const auto bDataShape = cast<ShapedType>(adaptor.getBData().getType());
2159 if (bDataShape.hasRank()) {
2160 const int64_t bDataBatchSize = bDataShape.getDimSize(0);
2161 if (bDataBatchSize != 1)
2162 outShape[0] =
2163 ShapedType::isDynamic(outShape[0]) ? bDataBatchSize : outShape[0];
2164 outShape[2] = bDataShape.getDimSize(1);
2165 }
2166
2167 const auto bScaleShape = cast<ShapedType>(adaptor.getBScale().getType());
2168 if (bScaleShape.hasRank()) {
2169 const int64_t bScaleBatchSize = bScaleShape.getDimSize(0);
2170 if (bScaleBatchSize != 1)
2171 outShape[0] =
2172 ShapedType::isDynamic(outShape[0]) ? bScaleBatchSize : outShape[0];
2173 outShape[2] = ShapedType::isDynamic(outShape[2]) ? bScaleShape.getDimSize(1)
2174 : outShape[2];
2175 }
2176
2177 inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
2178 return success();
2179}
2180
2181LogicalResult MatmulTBlockScaledOp::verify() {
2182 // Verify same input data types
2183 const Type aDataType = getAData().getType();
2184 const Type bDataType = getBData().getType();
2185 if (failed(verifySameElementTypes(*this, aDataType, bDataType, "A_data",
2186 "B_data")))
2187 return failure();
2188
2189 // Verify input shape compatibility
2190 int64_t N = ShapedType::kDynamic;
2191 int64_t D = ShapedType::kDynamic;
2192 int64_t H = ShapedType::kDynamic;
2193 int64_t W = ShapedType::kDynamic;
2194 int64_t C = ShapedType::kDynamic;
2195 int64_t multiplesOfC = ShapedType::kDynamic;
2196
2197 const ShapeAdaptor aDataShape = ShapeAdaptor(aDataType);
2198 if (aDataShape.hasRank()) {
2199 N = aDataShape.getDimSize(0);
2200 H = aDataShape.getDimSize(1);
2201 C = aDataShape.getDimSize(2);
2202 }
2203
2204 const ShapeAdaptor aScaleShape = ShapeAdaptor(getAScale().getType());
2205 if (aScaleShape.hasRank()) {
2206 if (failed(tryUpdateDimOrFailure(*this, N, aScaleShape.getDimSize(0),
2207 "a_scale", "batch")) ||
2208 failed(tryUpdateDimOrFailure(*this, H, aScaleShape.getDimSize(1),
2209 "a_scale", "height")))
2210 return failure();
2211 multiplesOfC = aScaleShape.getDimSize(2);
2212 }
2213
2214 const ShapeAdaptor bDataShape = ShapeAdaptor(bDataType);
2215 if (bDataShape.hasRank()) {
2216 if (failed(tryUpdateDimOrFailure(*this, D, bDataShape.getDimSize(0),
2217 "b_data", "batch")) ||
2218 failed(tryUpdateDimOrFailure(*this, C, bDataShape.getDimSize(2),
2219 "b_data", "channels")))
2220 return failure();
2221 W = bDataShape.getDimSize(1);
2222 }
2223
2224 const ShapeAdaptor bScaleShape = ShapeAdaptor(getBScale().getType());
2225 if (bScaleShape.hasRank()) {
2226 if (failed(tryUpdateDimOrFailure(*this, D, bScaleShape.getDimSize(0),
2227 "b_scale", "batch")) ||
2228 failed(tryUpdateDimOrFailure(*this, W, bScaleShape.getDimSize(1),
2229 "b_scale", "width")) ||
2230 failed(tryUpdateDimOrFailure(*this, multiplesOfC,
2231 bScaleShape.getDimSize(2), "b_scale",
2232 "C/block_size")))
2233 return failure();
2234 }
2235
2236 // Verify batch size is broadcast compatible
2237 if (ShapedType::isStatic(N) && ShapedType::isStatic(D) && N != D && D != 1)
2238 return emitOpError("expect B matrix batch size to be broadcast compatible "
2239 "with A, got D=")
2240 << D << " vs N=" << N;
2241
2242 // Verify C is a multiple of block size
2243 const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
2244 if (blockSize != BlockSizeAttr::getBlockSizeValue(BlockSize::BLOCK_SIZE_32))
2245 return emitOpError("expect block size to be 32, got ") << blockSize;
2246 if (ShapedType::isStatic(C) && C % blockSize != 0)
2247 return emitOpError("expect C to be a multiple of block size, got C=")
2248 << C << ", block_size=" << blockSize;
2249
2250 // Verify multiplesOfC is C / block size
2251 if (ShapedType::isStatic(C) && ShapedType::isStatic(multiplesOfC) &&
2252 multiplesOfC != C / blockSize)
2253 return emitOpError(
2254 "expect scale operands dimension 2 to equal C/block_size (")
2255 << C << "/" << blockSize << ")" << ", got " << multiplesOfC;
2256
2257 // Verify output shape
2258 N = ShapedType::isDynamic(N) ? D : N;
2259 const SmallVector<int64_t, 3> expectedOutputShape = {N, H, W};
2260 const auto outputType = cast<ShapedType>(getResult().getType());
2261 if (outputType.hasRank() &&
2262 failed(
2263 verifyCompatibleShape(outputType.getShape(), expectedOutputShape))) {
2264 InFlightDiagnostic opError = emitOpError("expected output shape ");
2265 printShapeToDiagnostic(opError, outputType.getShape());
2266 opError << " to be compatible with expected output shape ";
2267 printShapeToDiagnostic(opError, expectedOutputShape);
2268 return opError;
2269 }
2270
2271 return success();
2272}
2273
2274LogicalResult tosa::PadOp::inferReturnTypeComponents(
2275 MLIRContext *context, ::std::optional<Location> location,
2276 PadOp::Adaptor adaptor,
2277 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2278 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2279 auto paddingRank =
2280 cast<tosa::shapeType>(adaptor.getPadding().getType()).getRank();
2281 SmallVector<int64_t> outputShape;
2282
2283 // If the input rank is unknown, we can infer the output rank using the
2284 // padding shape's rank divided by 2.
2285 if (!inputShape.hasRank()) {
2286 outputShape.resize(paddingRank / 2, ShapedType::kDynamic);
2287 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2288 return success();
2289 }
2290
2291 SmallVector<int64_t> paddingValues;
2292 // If the paddings value is not a constant, all dimensions must be dynamic.
2293 if (!tosa::getConstShapeValues(adaptor.getPadding().getDefiningOp(),
2294 paddingValues)) {
2295 outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
2296 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2297 return success();
2298 }
2299
2300 outputShape.reserve(inputShape.getRank());
2301 for (int i = 0, s = inputShape.getRank(); i < s; i++) {
2302 if (inputShape.isDynamicDim(i)) {
2303 outputShape.push_back(ShapedType::kDynamic);
2304 continue;
2305 }
2306 auto padFront = paddingValues[i * 2];
2307 auto padBack = paddingValues[i * 2 + 1];
2308 if (padFront < 0 || padBack < 0) {
2309 // if either padding for dim i is -1, output dim is unknown
2310 outputShape.push_back(ShapedType::kDynamic);
2311 continue;
2312 }
2313
2314 outputShape.push_back(inputShape.getDimSize(i) + padFront + padBack);
2315 }
2316
2317 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2318 return success();
2319}
2320
2321LogicalResult tosa::PadOp::verify() {
2322 if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
2323 /* outType = */ getOutput().getType())
2324 .failed()) {
2325 return failure();
2326 }
2327
2328 if (auto padConst = getPadConst()) {
2329 if (verifySameElementTypes(*this, /* inType = */ padConst.getType(),
2330 /* outType = */ getOutput().getType())
2331 .failed()) {
2332 return failure();
2333 }
2334 }
2335
2336 RankedTensorType inputType =
2337 llvm::dyn_cast<RankedTensorType>(getInput1().getType());
2338 RankedTensorType outputType =
2339 llvm::dyn_cast<RankedTensorType>(getOutput().getType());
2340 if (!inputType || !outputType)
2341 return success();
2342
2343 if (failed(verifyRanksMatch(getOperation(), inputType, outputType, "input",
2344 "output")))
2345 return failure();
2346
2347 auto inputRank = inputType.getRank();
2348 DenseIntElementsAttr paddingAttr;
2349 if (!matchPattern(getPadding(), m_Constant(&paddingAttr)))
2350 return success();
2351
2352 auto paddingValues = paddingAttr.getValues<APInt>();
2353 if (paddingValues.size() != static_cast<size_t>(inputRank * 2))
2354 return emitOpError() << "padding tensor must have " << inputRank
2355 << " * 2 = " << inputRank * 2 << " elements, but got "
2356 << paddingValues.size();
2357
2358 auto inputShape = inputType.getShape();
2359 auto outputShape = outputType.getShape();
2360
2361 for (int64_t i = 0; i < inputRank; ++i) {
2362 int64_t padStart = paddingValues[i * 2].getSExtValue();
2363 int64_t padEnd = paddingValues[i * 2 + 1].getSExtValue();
2364
2365 if ((padStart < 0 && padStart != -1) || (padEnd < 0 && padEnd != -1)) {
2366 return emitOpError()
2367 << "invalid padding values at dimension " << i
2368 << ": values must be non-negative or -1 for dynamic padding, got ["
2369 << padStart << ", " << padEnd << "]";
2370 }
2371
2372 // Skip shape verification for dynamic input/output
2373 if (inputShape[i] == ShapedType::kDynamic ||
2374 outputShape[i] == ShapedType::kDynamic)
2375 continue;
2376
2377 if (outputShape[i] != inputShape[i] + padStart + padEnd) {
2378 return emitOpError() << "mismatch in output shape at dimension " << i
2379 << ": expected " << inputShape[i] << " + "
2380 << padStart << " + " << padEnd << " = "
2381 << (inputShape[i] + padStart + padEnd)
2382 << ", but got " << outputShape[i];
2383 }
2384 }
2385
2386 return success();
2387}
2388
2389LogicalResult tosa::SliceOp::inferReturnTypeComponents(
2390 MLIRContext *context, ::std::optional<Location> location,
2391 SliceOp::Adaptor adaptor,
2392 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2393
2394 Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
2397
2398 if (!tosa::getConstShapeValues(adaptor.getStart().getDefiningOp(), start) ||
2399 !tosa::getConstShapeValues(adaptor.getSize().getDefiningOp(), size)) {
2400 auto rank = cast<tosa::shapeType>(adaptor.getSize().getType()).getRank();
2401 SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
2402 inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
2403 return success();
2404 }
2405
2406 // if size[i] is -1, all remaining elements in dimension i are included
2407 // in the slice, similar to TF.
2408 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2409 // initialize outputShape to all unknown
2410 SmallVector<int64_t> outputShape(size.size(), ShapedType::kDynamic);
2411 if (inputShape.hasRank()) {
2412 for (size_t i = 0; i < size.size(); i++) {
2413 if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
2414 (ShapedType::isDynamic(inputShape.getDimSize(i)) ||
2415 start[i] < inputShape.getDimSize(i))) {
2416 // size[i] is not 0 and not < -1, and start[i] is in valid range
2417 if (ShapedType::isDynamic(inputShape.getDimSize(i))) {
2418 // input shape has unknown dim[i] - only valid if size[i] > 0
2419 if (size[i] > 0) {
2420 outputShape[i] = size[i];
2421 }
2422 } else {
2423 // input shape has known dim[i]
2424 if (size[i] == -1) {
2425 outputShape[i] = inputShape.getDimSize(i) - start[i];
2426 } else if (start[i] + size[i] <= inputShape.getDimSize(i)) {
2427 // start[i] + size[i] is within bound of input shape's dim[i]
2428 outputShape[i] = size[i];
2429 }
2430 }
2431 }
2432 }
2433 } else {
2434 outputShape = convertToMlirShape(size);
2435 }
2436 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2437 return success();
2438}
2439
2440LogicalResult tosa::SliceOp::verify() {
2441 const Value input = getInput1();
2442 const Value output = getOutput();
2443 if (verifySameElementTypes(*this, /* inType = */ input.getType(),
2444 /* outType = */ output.getType())
2445 .failed())
2446 return failure();
2447
2448 const Value start = getStart();
2449 const Value size = getSize();
2450 const ShapeAdaptor inputShape(input.getType());
2451 const ShapeAdaptor outputShape(output.getType());
2452
2453 if (inputShape.hasRank()) {
2454 const auto inputRank = inputShape.getRank();
2455 if (outputShape.hasRank() && inputRank != outputShape.getRank())
2456 return emitOpError(
2457 "expect input1 and output to have the same ranks, got ")
2458 << inputRank << " and " << outputShape.getRank();
2459
2460 const auto startShapeRank =
2461 llvm::cast<tosa::shapeType>(start.getType()).getRank();
2462 if (inputRank != startShapeRank)
2463 return emitOpError("length of start is not equal to rank of input shape");
2464
2465 const auto sizeShapeRank =
2466 llvm::cast<tosa::shapeType>(size.getType()).getRank();
2467 if (inputRank != sizeShapeRank)
2468 return emitOpError("length of size is not equal to rank of input shape");
2469 }
2470
2471 SmallVector<int64_t> startValues;
2472 tosa::getConstShapeValues(start.getDefiningOp(), startValues);
2473 if (startValues.size()) {
2474 if (llvm::any_of(startValues, [](const int64_t v) {
2475 return v < 0 && v != kInferableDimSize;
2476 }))
2477 return emitOpError("start values must be non-negative, got [")
2478 << startValues << "]";
2479 }
2480
2481 SmallVector<int64_t> sizeValues;
2482 if (!tosa::getConstShapeValues(size.getDefiningOp(), sizeValues))
2483 return success();
2484
2485 if (llvm::any_of(sizeValues, [](const int64_t v) {
2486 return v <= 0 && v != kInferableDimSize;
2487 }))
2488 return emitOpError("size values must be > 0, got [") << sizeValues << "]";
2489 if (outputShape.hasRank()) {
2490 SmallVector<int64_t> outputDims;
2491 outputShape.getDims(outputDims);
2492 const bool hasNoInferableDims = llvm::all_of(
2493 sizeValues, [](const int64_t v) { return v != kInferableDimSize; });
2494 if (hasNoInferableDims &&
2495 failed(verifyCompatibleShape(outputDims, sizeValues)))
2496 return emitOpError("expected output shape to match size values, got ")
2497 << output.getType() << " vs [" << sizeValues << "]";
2498 }
2499
2500 if (inputShape.hasRank() && startValues.size()) {
2501 SmallVector<int64_t> inputDims;
2502 inputShape.getDims(inputDims);
2503 for (const auto &[index, vals] :
2504 llvm::enumerate(llvm::zip_equal(startValues, sizeValues, inputDims))) {
2505 const auto &[start, size, inputDim] = vals;
2506 if (start == kInferableDimSize || size == kInferableDimSize ||
2507 ShapedType::isDynamic(inputDim))
2508 continue;
2509 if (start + size > inputDim)
2510 return emitOpError("start + size must be less than or equal to input "
2511 "dimension size, got start=")
2512 << start << ", size=" << size
2513 << " vs input dim size=" << inputDim << " at dimension "
2514 << index;
2515 }
2516 }
2517
2518 return success();
2519}
2520
2521LogicalResult tosa::MulOp::inferReturnTypeComponents(
2522 MLIRContext *context, ::std::optional<Location> location,
2523 ValueShapeRange operands, DictionaryAttr attributes, PropertyRef properties,
2524 RegionRange regions,
2525 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2526 // mul op's output shape only depend on input1 and input2, not on shift
2527 ValueShapeRange twoInputs = operands.drop_back();
2529 if (resolveBroadcastShape(twoInputs, outShape).failed()) {
2530 inferredReturnShapes.push_back(ShapedTypeComponents());
2531 } else {
2532 inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
2533 }
2534 return success();
2535}
2536
2537LogicalResult tosa::MulOp::verify() {
2538 const Value output = getOutput();
2539 auto resElemType = getElementTypeOrSelf(output);
2540
2541 // Verify if the element type among operands and result match tosa
2542 // specification.
2543 if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
2544 IntegerType lhsIntType =
2545 dyn_cast<IntegerType>(getElementTypeOrSelf(getInput1()));
2546 IntegerType rhsIntType =
2547 dyn_cast<IntegerType>(getElementTypeOrSelf(getInput2()));
2548 if (!lhsIntType || !rhsIntType || lhsIntType != rhsIntType)
2549 return emitOpError("requires the same element type for all operands");
2550
2551 // Though the spec requires the element type of result to be i32, a more
2552 // relaxed way is provided at dialect level for easier cooperating with
2553 // other dialects.
2554 if (lhsIntType.getWidth() > resIntType.getWidth())
2555 return emitOpError("invalid data type size for operands or result");
2556
2557 } else {
2558 // For other supported type, the spec requires requires the same element
2559 // type for all operands (excludes `shift` operand) and results.
2560 for (int i = 0; i < 2; ++i) {
2561 if (getElementTypeOrSelf(getOperand(i)) != resElemType)
2562 return emitOpError(
2563 "requires the same element type for all operands and results");
2564 }
2565
2566 // verify shift has value 0 for non-integer types
2567 ElementsAttr shiftElem;
2568 if (matchPattern(getShift(), m_Constant(&shiftElem))) {
2569 int32_t shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
2570 if (shift != 0) {
2571 return emitOpError() << "require shift to be 0 for float type";
2572 }
2573 }
2574 }
2575
2576 // Verify the op has same ranks for all main operands (excludes extra operands
2577 // such as shift of mul op, so this is the only difference with the built-in
2578 // `SameOperandsAndResultRank` trait) and results types, if known.
2579 TypeRange operandTypes = getOperandTypes();
2580 ShapedType aType = cast<ShapedType>(operandTypes[0]);
2581 ShapedType bType = cast<ShapedType>(operandTypes[1]);
2582
2583 const bool aHasRank = aType.hasRank();
2584 const bool bHasRank = bType.hasRank();
2585
2586 bool hasExpectedOutputShape = false;
2587 SmallVector<int64_t> expectedOutputShape;
2588
2589 if (aHasRank && bHasRank) {
2590 const int64_t aRank = aType.getRank();
2591 const int64_t bRank = bType.getRank();
2592 if (aRank != bRank)
2593 return emitOpError("a and b operands don't have matching ranks, got ")
2594 << aRank << " and " << bRank;
2595
2596 // check for broadcast compatible shapes
2598 aType.getShape(), bType.getShape(), expectedOutputShape))
2599 return emitOpError("a and b operands don't have broadcast-compatible "
2600 "shapes, got ")
2601 << aType << " and " << bType;
2602 hasExpectedOutputShape = true;
2603 }
2604
2605 ShapedType resultType = cast<ShapedType>(output.getType());
2606 if (!resultType.hasRank())
2607 return success();
2608
2609 const int64_t resultRank = resultType.getRank();
2610 if (aHasRank && resultRank != aType.getRank())
2611 return emitOpError("result type has different rank than a, got ")
2612 << resultRank << " vs " << aType.getRank();
2613 if (bHasRank && resultRank != bType.getRank())
2614 return emitOpError("result type has different rank than b, got ")
2615 << resultRank << " vs " << bType.getRank();
2616
2617 if (hasExpectedOutputShape &&
2618 failed(verifyOutputShapeCompatibleWithExpected(getOperation(), resultType,
2619 expectedOutputShape)))
2620 return failure();
2621
2622 return success();
2623}
2624
2625LogicalResult tosa::TableOp::inferReturnTypeComponents(
2626 MLIRContext *context, ::std::optional<Location> location,
2627 TableOp::Adaptor adaptor,
2628 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2629 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2630
2631 if (!inputShape.hasRank()) {
2632 inferredReturnShapes.push_back(ShapedTypeComponents());
2633 return success();
2634 }
2635
2636 inferredReturnShapes.resize(1);
2637 inputShape.getDims(inferredReturnShapes[0]);
2638 return success();
2639}
2640
2641LogicalResult tosa::TableOp::verify() {
2642 const TensorType inputType = getInput1().getType();
2643 const TensorType outputType = getOutput().getType();
2644
2645 if (!inputType.hasRank() || !outputType.hasRank())
2646 return success();
2647
2648 if (failed(verifyRanksMatch(getOperation(), inputType, outputType, "input",
2649 "result")))
2650 return failure();
2651
2652 auto inputDims = inputType.getShape();
2653 auto outputDims = outputType.getShape();
2654 for (auto it : llvm::enumerate(llvm::zip(inputDims, outputDims))) {
2655 int64_t dim = it.index();
2656 auto [inputDim, outputDim] = it.value();
2657 if (ShapedType::isStatic(outputDim) && outputDim != inputDim) {
2658 return emitOpError() << "dim(result, " << dim << ") = " << outputDim
2659 << " doesn't match dim(input, " << dim
2660 << ") = " << inputDim;
2661 }
2662 }
2663 return success();
2664}
2665
2666LogicalResult
2667tosa::TileOp::getConstantMultiples(SmallVector<int64_t> &multiples) {
2668 // Multiples must be constants.
2669 DenseIntElementsAttr multiplesAttr;
2670 if (!matchPattern(getMultiples(), m_Constant(&multiplesAttr)))
2671 return failure();
2672 multiples =
2673 llvm::map_to_vector(multiplesAttr.getValues<APInt>(),
2674 [](const APInt &val) { return val.getSExtValue(); });
2675 return success();
2676}
2677
2678LogicalResult tosa::TileOp::inferReturnTypeComponents(
2679 MLIRContext *context, ::std::optional<Location> location,
2680 TileOp::Adaptor adaptor,
2681 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2682 Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
2683 SmallVector<int64_t> multiples;
2684 if (!tosa::getConstShapeValues(adaptor.getMultiples().getDefiningOp(),
2685 multiples)) {
2686 auto rank =
2687 cast<tosa::shapeType>(adaptor.getMultiples().getType()).getRank();
2688 SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
2689 inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
2690 return success();
2691 }
2692 multiples = convertToMlirShape(multiples);
2693
2694 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2695 SmallVector<int64_t> outputShape;
2696 if (!inputShape.hasRank()) {
2697 outputShape.resize(multiples.size(), ShapedType::kDynamic);
2698 inferredReturnShapes.push_back(
2699 ShapedTypeComponents(outputShape, inputType));
2700 return success();
2701 }
2702 if (static_cast<size_t>(inputShape.getRank()) != multiples.size())
2703 return failure();
2704
2705 // Any non dynamic dimension can be multiplied to a known size.
2706 outputShape.reserve(multiples.size());
2707 for (int i = 0, s = inputShape.getRank(); i < s; i++) {
2708 if (multiples[i] == ShapedType::kDynamic) {
2709 outputShape.push_back(ShapedType::kDynamic);
2710 } else {
2711 int64_t dim = inputShape.getDimSize(i);
2712 if (dim != ShapedType::kDynamic)
2713 dim *= multiples[i];
2714 outputShape.push_back(dim);
2715 }
2716 }
2717
2718 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
2719 return success();
2720}
2721
2722LogicalResult tosa::TileOp::verify() {
2723 if (verifySameElementTypes(*this, /* intype = */ getInput1().getType(),
2724 /* outType = */ getOutput().getType())
2725 .failed()) {
2726 return failure();
2727 }
2728 ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
2729 ShapedType outputType = llvm::cast<ShapedType>(getType());
2730
2731 shapeType multiplesType =
2732 llvm::cast<tosa::shapeType>(getMultiples().getType());
2733
2734 auto multiplesRank = multiplesType.getRank();
2735
2736 if (inputType.hasRank()) {
2737 if (inputType.getRank() != multiplesRank)
2738 return emitOpError("expect 'multiples' to have rank ")
2739 << inputType.getRank() << " but got " << multiplesRank << ".";
2740 if (outputType.hasRank() &&
2741 failed(verifyRanksMatch(getOperation(), inputType, outputType, "input",
2742 "output")))
2743 return failure();
2744 } else if (outputType.hasRank() && outputType.getRank() != multiplesRank)
2745 return emitOpError("expect 'multiples' array to have length ")
2746 << outputType.getRank() << " but got " << multiplesRank << ".";
2747
2748 SmallVector<int64_t> multiples;
2749 if (getConstantMultiples(multiples).succeeded() &&
2750 llvm::any_of(multiples, [](int64_t v) { return v <= 0 && v != -1; }))
2751 return emitOpError(
2752 "expect element of 'multiples' to be positive integer or -1.");
2753
2754 return success();
2755}
2756
2757bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
2758 if (l.size() != r.size() || l.size() != 1)
2759 return false;
2760 return getElementTypeOrSelf(l[0]) == getElementTypeOrSelf(r[0]);
2761}
2762
2763LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
2764 MLIRContext *context, ::std::optional<Location> location,
2765 ReshapeOp::Adaptor adaptor,
2766 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2767 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2768 Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
2769 llvm::SmallVector<int64_t> newShapeValue;
2770 if (!tosa::getConstShapeValues(adaptor.getShape().getDefiningOp(),
2771 newShapeValue)) {
2772 auto rank = cast<tosa::shapeType>(adaptor.getShape().getType()).getRank();
2773 SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
2774 inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
2775 return success();
2776 }
2777 newShapeValue = convertToMlirShape(newShapeValue);
2778
2779 // We cannot infer from the total number of elements so we must take the
2780 // shape attribute as exact.
2781 if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
2782 inferredReturnShapes.push_back(
2783 ShapedTypeComponents(newShapeValue, inputType));
2784 return success();
2785 }
2786
2787 // Determine the number of elements covered by the slice of all static
2788 // dimensions. This allows us to infer the length of the remaining dynamic
2789 // dimension.
2790 int64_t numElements = inputShape.getNumElements();
2791 int64_t staticMul = 1;
2792 for (auto val : newShapeValue) {
2793 if (ShapedType::isStatic(val)) {
2794 staticMul *= val;
2795 }
2796 }
2797
2798 // Determine the length of the dynamic dimension.
2799 for (auto &val : newShapeValue) {
2800 if (ShapedType::isDynamic(val))
2801 val = numElements / staticMul;
2802 }
2803
2804 inferredReturnShapes.push_back(
2805 ShapedTypeComponents(newShapeValue, inputType));
2806 return success();
2807}
2808
2809llvm::LogicalResult tosa::ReshapeOp::verify() {
2810 if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
2811 /* outType = */ getOutput().getType())
2812 .failed()) {
2813 return failure();
2814 }
2815 TensorType inputType = getInput1().getType();
2816
2817 SmallVector<int64_t> shapeValues;
2818 if (!tosa::getConstShapeValues(getShape().getDefiningOp(), shapeValues)) {
2819 // skip following checks if shape is not constant
2820 return mlir::success();
2821 }
2822
2823 int missingDims = llvm::count(shapeValues, kInferableDimSize);
2824 if (missingDims > 1)
2825 return emitOpError() << "expected at most one target dimension to be "
2827
2828 const auto outputType = dyn_cast<RankedTensorType>(getType());
2829 if (!outputType)
2830 return success();
2831
2832 if ((int64_t)shapeValues.size() != outputType.getRank())
2833 return emitOpError() << "new shape does not match result rank";
2834
2835 for (auto [newShapeDim, outputShapeDim] :
2836 zip(shapeValues, outputType.getShape())) {
2837 if (newShapeDim != kInferableDimSize &&
2838 newShapeDim != ShapedType::kDynamic &&
2839 outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim)
2840 return emitOpError() << "new shape is inconsistent with result shape";
2841
2842 if (newShapeDim != ShapedType::kDynamic && newShapeDim < kInferableDimSize)
2843 return emitOpError() << "new shape has invalid tensor dimension size "
2844 << newShapeDim;
2845 }
2846
2847 if (inputType.hasStaticShape()) {
2848 int64_t inputElementsNum = inputType.getNumElements();
2849 if (outputType.hasStaticShape()) {
2850 int64_t outputElementsNum = outputType.getNumElements();
2851 if (inputElementsNum != outputElementsNum) {
2852 return emitOpError() << "cannot reshape " << inputElementsNum
2853 << " elements into " << outputElementsNum;
2854 }
2855 }
2856
2857 int64_t newShapeElementsNum =
2858 llvm::accumulate(shapeValues, int64_t(1), [](int64_t acc, int64_t dim) {
2859 return (dim > 0) ? acc * dim : acc;
2860 });
2861 bool isStaticNewShape =
2862 llvm::all_of(shapeValues, [](int64_t s) { return s > 0; });
2863 if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
2864 (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
2865 return emitOpError() << "cannot reshape " << inputElementsNum
2866 << " elements into " << newShapeElementsNum;
2867 }
2868 }
2869
2870 return mlir::success();
2871}
2872
2873bool tosa::ReshapeBlockScaledOp::isCompatibleReturnTypes(TypeRange l,
2874 TypeRange r) {
2875 if (l.size() != r.size() || l.size() < 1 || l.size() > 2)
2876 return false;
2877 bool ok = (getElementTypeOrSelf(l[0]) == getElementTypeOrSelf(r[0]));
2878 if (l.size() == 2)
2879 ok = ok && (getElementTypeOrSelf(l[1]) == getElementTypeOrSelf(r[1]));
2880 return ok;
2881}
2882
2883LogicalResult tosa::ReshapeBlockScaledOp::inferReturnTypeComponents(
2884 MLIRContext *context, ::std::optional<Location> location,
2885 ReshapeBlockScaledOp::Adaptor adaptor,
2886 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2887
2888 const auto numInputs = adaptor.getInput().size();
2889 ShapeAdaptor inputShape(adaptor.getInput()[0].getType());
2890 Type inputType = getElementTypeOrSelf(adaptor.getInput()[0].getType());
2891 llvm::SmallVector<int64_t> newShapeValue;
2892 const auto newShape = adaptor.getNewValueShape();
2893 if (!tosa::getConstShapeValues(newShape.getDefiningOp(), newShapeValue)) {
2894 auto rank = cast<tosa::shapeType>(newShape.getType()).getRank();
2895 SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
2896 inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
2897 if (numInputs == 2)
2898 inferredReturnShapes.push_back(ShapedTypeComponents(
2899 fallback, getElementTypeOrSelf(adaptor.getInput()[1].getType())));
2900 return success();
2901 }
2902
2903 const uint32_t blockSize =
2904 BlockSizeAttr::getBlockSizeValue(adaptor.getBlockSize());
2905
2906 llvm::SmallVector<int64_t> newScaleShapeValue;
2907 if (numInputs == 2) {
2908 newScaleShapeValue.assign(newShapeValue.begin(), newShapeValue.end());
2909 if (ShapedType::isStatic(newScaleShapeValue.back()))
2910 newScaleShapeValue.back() /= blockSize;
2911 }
2912
2913 inferredReturnShapes.push_back(
2914 ShapedTypeComponents(newShapeValue, inputType));
2915 if (numInputs == 2) {
2916 // Fix up scale shape - with special case for last dimension
2917 for (size_t idx = 0; idx < newShapeValue.size(); idx++) {
2918 if (ShapedType::isDynamic(newScaleShapeValue[idx])) {
2919 newScaleShapeValue[idx] = newShapeValue[idx];
2920 if (idx == (newShapeValue.size() - 1))
2921 newScaleShapeValue[idx] /= blockSize;
2922 }
2923 }
2924
2925 inferredReturnShapes.push_back(ShapedTypeComponents(
2926 newScaleShapeValue,
2927 getElementTypeOrSelf(adaptor.getInput()[1].getType())));
2928 }
2929 return success();
2930}
2931
2932llvm::LogicalResult tosa::ReshapeBlockScaledOp::verify() {
2933 const Operation::operand_range inputList = getInput();
2934 const Operation::result_range outputList = getResults();
2935
2936 if (inputList.size() == 0)
2937 return emitOpError("requires at least one input");
2938
2939 if (inputList.size() > 2)
2940 return emitOpError("requires at most two inputs");
2941
2942 if (inputList.size() != outputList.size())
2943 return emitOpError("requires number of results to match inputs");
2944
2945 if (verifySameElementTypes(*this, /* inType = */ inputList[0].getType(),
2946 /* outType = */ outputList[0].getType())
2947 .failed()) {
2948 return failure();
2949 }
2950
2951 const auto inputType = llvm::cast<ShapedType>(inputList[0].getType());
2952 if (!inputType.hasRank())
2953 return success();
2954 const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
2955
2956 if (inputList.size() == 2) {
2957 if (blockSize != BlockSizeAttr::getBlockSizeValue(BlockSize::BLOCK_SIZE_32))
2958 return emitOpError("expect block size to be 32, got ") << blockSize;
2959 if (llvm::any_of(inputList, [](Value v) {
2960 const auto input = cast<ShapedType>(v.getType());
2961 return input.hasRank() && input.getRank() == 0;
2962 }))
2963 return emitOpError(
2964 "requires all input shapes have a rank greater than 0");
2965 if (llvm::any_of(outputList, [](Value v) {
2966 const auto output = cast<ShapedType>(v.getType());
2967 return output.hasRank() && output.getRank() == 0;
2968 }))
2969 return emitOpError(
2970 "requires all result shapes have a rank greater than 0");
2971
2972 if (verifySameElementTypes(*this, /* inType = */ inputList[1].getType(),
2973 /* outType = */ outputList[1].getType())
2974 .failed()) {
2975 return failure();
2976 }
2977
2978 const auto inputScaleType = llvm::cast<ShapedType>(inputList[1].getType());
2979 if (inputScaleType.hasRank()) {
2980 if (inputType.getRank() != inputScaleType.getRank())
2981 return emitOpError("input shapes do not have same rank");
2982
2983 // Check all but the last dimension that the input shape dimensions match
2984 for (auto dimIdx = 0; dimIdx < inputType.getRank() - 1; dimIdx++) {
2985 const int64_t inputValueDim = inputType.getDimSize(dimIdx);
2986 const int64_t inputScaleDim = inputScaleType.getShape()[dimIdx];
2987 if (ShapedType::isStatic(inputValueDim) &&
2988 ShapedType::isStatic(inputScaleDim) &&
2989 inputValueDim != inputScaleDim)
2990 return emitOpError("input shapes for data and scale do not match on "
2991 "dimension ")
2992 << dimIdx;
2993 }
2994
2995 // Verify last dimension of input is a multiple of block size
2996 const int64_t lastValueDim =
2997 inputType.getDimSize(inputType.getRank() - 1);
2998 if (ShapedType::isStatic(lastValueDim)) {
2999 if (lastValueDim % blockSize != 0)
3000 return emitOpError("expect last dimension of input_data (")
3001 << lastValueDim << ") to be divisible by block_size ("
3002 << blockSize << ")";
3003
3004 const int64_t lastScaleDim =
3005 inputScaleType.getDimSize(inputScaleType.getRank() - 1);
3006 // Verify last dimension of scale is lastValueDim / block size
3007 if (ShapedType::isStatic(lastScaleDim) &&
3008 lastScaleDim != lastValueDim / blockSize)
3009 return emitOpError("expect last dimension of scale_data (")
3010 << lastScaleDim << ") to be " << lastValueDim << "/"
3011 << blockSize;
3012 }
3013 }
3014 } else {
3015 if (blockSize != BlockSizeAttr::getBlockSizeValue(BlockSize::BLOCK_SIZE_1))
3016 return emitOpError("expect block size to be 1, got ") << blockSize;
3017 }
3018
3019 // Get the new value shape dimension values
3020 SmallVector<int64_t> shapeValues;
3021 if (!tosa::getConstShapeValues(getNewValueShape().getDefiningOp(),
3022 shapeValues)) {
3023 // skip following checks if shape is not constant
3024 return mlir::success();
3025 }
3026
3027 if (inputList.size() == 2) {
3028 if (static_cast<int64_t>(shapeValues.size()) == 0)
3029 return emitOpError("requires new shape to have a rank greater than 0");
3030
3031 const int64_t lastShapeDim = shapeValues.back();
3032 if (ShapedType::isStatic(lastShapeDim) && lastShapeDim % blockSize != 0)
3033 return emitOpError("expect last dimension of new shape (")
3034 << lastShapeDim << ") to be divisible by block_size (" << blockSize
3035 << ")";
3036 }
3037
3038 const auto outputType = llvm::cast<ShapedType>(outputList[0].getType());
3039 if (!outputType.hasRank())
3040 return success();
3041
3042 if (static_cast<int64_t>(shapeValues.size()) != outputType.getRank())
3043 return emitOpError() << "result does not match new shape rank";
3044
3045 for (auto [newShapeDim, outputShapeDim] :
3046 zip(shapeValues, outputType.getShape())) {
3047 if (ShapedType::isStatic(newShapeDim) &&
3048 ShapedType::isStatic(outputShapeDim) && newShapeDim != outputShapeDim)
3049 return emitOpError() << "result shape is inconsistent with new shape";
3050 }
3051
3052 if (outputList.size() == 2) {
3053 // Set up scale shape from new shape given
3054 SmallVector<int64_t> scaleShapeValues(shapeValues.begin(),
3055 shapeValues.end());
3056 scaleShapeValues.back() /= blockSize;
3057
3058 const auto outputScaleType =
3059 llvm::cast<ShapedType>(outputList[1].getType());
3060 if (outputScaleType.hasRank()) {
3061 if ((int64_t)scaleShapeValues.size() != outputScaleType.getRank())
3062 return emitOpError() << "result scale does not match new shape rank";
3063
3064 for (auto [newScaleShapeDim, outputScaleShapeDim] :
3065 zip(scaleShapeValues, outputScaleType.getShape())) {
3066 if (ShapedType::isStatic(newScaleShapeDim) &&
3067 ShapedType::isStatic(outputScaleShapeDim) &&
3068 newScaleShapeDim != outputScaleShapeDim)
3069 return emitOpError()
3070 << "result scale shape is inconsistent with new shape";
3071 }
3072 }
3073 }
3074
3075 if (inputType.hasStaticShape()) {
3076 int64_t inputElementsNum = inputType.getNumElements();
3077 if (outputType.hasStaticShape()) {
3078 int64_t outputElementsNum = outputType.getNumElements();
3079 if (inputElementsNum != outputElementsNum) {
3080 return emitOpError() << "cannot reshape " << inputElementsNum
3081 << " elements into " << outputElementsNum;
3082 }
3083 }
3084
3085 int64_t newShapeElementsNum =
3086 llvm::accumulate(shapeValues, int64_t(1), [](int64_t acc, int64_t dim) {
3087 return (dim > 0) ? acc * dim : acc;
3088 });
3089 bool isStaticNewShape =
3090 llvm::all_of(shapeValues, [](int64_t s) { return s > 0; });
3091 if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
3092 (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
3093 return emitOpError() << "cannot reshape " << inputElementsNum
3094 << " elements into " << newShapeElementsNum;
3095 }
3096 }
3097
3098 return mlir::success();
3099}
3100
3101// return failure if val is not a constant
3102// set zp to -1 if val is non-zero float or val is not integer nor float
3103// otherwise set zp to val's constant value
3104static FailureOr<int64_t> getZeroPoint(Value val, bool signExtend) {
3105 ElementsAttr zpAttr;
3106 if (!matchPattern(val, m_Constant(&zpAttr))) {
3107 return failure();
3108 }
3109
3110 Type zpElemType = zpAttr.getElementType();
3111
3112 if (llvm::isa<FloatType>(zpElemType)) {
3113 if (zpAttr.getValues<APFloat>()[0].isZero()) {
3114 return 0;
3115 }
3116 // return non-zero value to trigger error check
3117 return -1;
3118 }
3119
3120 if (llvm::isa<IntegerType>(zpElemType)) {
3121 if (signExtend)
3122 return zpAttr.getValues<APInt>()[0].getSExtValue();
3123 return zpAttr.getValues<APInt>()[0].getZExtValue();
3124 }
3125
3126 // return non-zero value to trigger error check
3127 return -1;
3128}
3129
3130template <typename T>
3131static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp,
3132 const std::string &operand) {
3133 Type zpElemType = getElementTypeOrSelf(val);
3134
3135 if (!zpElemType.isInteger(8) && zp != 0) {
3136 // convert operand to lower case for error message
3137 std::string lower = operand;
3138 llvm::transform(lower, lower.begin(), ::tolower);
3139 return op.emitOpError()
3140 << lower << " zero point must be zero for non-int8 integer types";
3141 }
3142
3143 return success();
3144}
3145
3146static LogicalResult verifyZeroPoint(tosa::RescaleOp op, Value zpVal,
3147 const int64_t &zp,
3148 const std::string &operand) {
3149 bool isInputZp = (operand == "Input");
3150
3151 bool tensorUnsigned =
3152 isInputZp ? op.getInputUnsigned() : op.getOutputUnsigned();
3153 StringRef tensorName = isInputZp ? "input" : "output";
3154
3155 Type zpElemType = getElementTypeOrSelf(zpVal);
3156
3157 if (zp != 0) {
3158 if (!zpElemType.isInteger(8) &&
3159 !(zpElemType.isInteger(16) && tensorUnsigned)) {
3160 return op.emitOpError()
3161 << "expect " << tensorName << "_zp of 0, got " << zp;
3162 }
3163 if (zpElemType.isInteger(16) && tensorUnsigned && zp != 32768) {
3164 return op.emitOpError() << "expect " << tensorName
3165 << "_zp of 0 or 32768 for unsigned int16 "
3166 << tensorName << ", got " << zp;
3167 }
3168 }
3169
3170 return success();
3171}
3172
3173#define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND) \
3174 FailureOr<int64_t> tosa::OP::get##OPERAND_NAME##ZeroPoint() { \
3175 return getZeroPoint(get##OPERAND_NAME##Zp(), SIGN_EXTEND); \
3176 } \
3177 LogicalResult tosa::OP::verify##OPERAND_NAME##ZeroPoint(int64_t zp) { \
3178 return verifyZeroPoint(*this, get##OPERAND_NAME##Zp(), zp, #OPERAND_NAME); \
3179 }
3180
3181ZERO_POINT_HELPER(Conv2DOp, Input, true)
3182ZERO_POINT_HELPER(Conv2DOp, Weight, true)
3183ZERO_POINT_HELPER(Conv3DOp, Input, true)
3184ZERO_POINT_HELPER(Conv3DOp, Weight, true)
3185ZERO_POINT_HELPER(DepthwiseConv2DOp, Input, true)
3186ZERO_POINT_HELPER(DepthwiseConv2DOp, Weight, true)
3187ZERO_POINT_HELPER(TransposeConv2DOp, Input, true)
3188ZERO_POINT_HELPER(TransposeConv2DOp, Weight, true)
3189ZERO_POINT_HELPER(AvgPool2dOp, Input, true)
3190ZERO_POINT_HELPER(AvgPool2dOp, Output, true)
3191ZERO_POINT_HELPER(AvgPool2dAdaptiveOp, Input, true)
3192ZERO_POINT_HELPER(AvgPool2dAdaptiveOp, Output, true)
3193ZERO_POINT_HELPER(MatMulOp, A, true)
3194ZERO_POINT_HELPER(MatMulOp, B, true)
3195ZERO_POINT_HELPER(NegateOp, Input1, true)
3196ZERO_POINT_HELPER(NegateOp, Output, true)
3197ZERO_POINT_HELPER(RescaleOp, Input, !getInputUnsigned())
3198ZERO_POINT_HELPER(RescaleOp, Output, !getOutputUnsigned())
3199#undef ZERO_POINT_HELPER
3200
3201LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
3202 MLIRContext *context, ::std::optional<Location> location,
3203 TransposeOp::Adaptor adaptor,
3204 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3205 ShapeAdaptor inputShape(adaptor.getInput1().getType());
3206
3207 // If input rank and permutation length is unknown, the output rank is
3208 // unknown.
3209 if (!inputShape.hasRank()) {
3210 inferredReturnShapes.push_back(ShapedTypeComponents());
3211 return success();
3212 }
3213
3214 const auto inputRank = inputShape.getRank();
3215
3216 // This would imply the number of permutations does not match the rank of
3217 // the input which is illegal.
3218 if (adaptor.getPerms().size() != static_cast<size_t>(inputRank)) {
3219 return failure();
3220 }
3221
3222 SmallVector<int64_t> outputShape;
3223 // Rank-0 means no permutations matter.
3224 if (inputRank == 0) {
3225 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3226 return success();
3227 }
3228
3229 // Check whether the input dimensions are all the same.
3230 bool allTheSame = true;
3231 for (int i = 1, s = inputRank; i < s; i++) {
3232 if (inputShape.getDimSize(0) != inputShape.getDimSize(i)) {
3233 allTheSame = false;
3234 break;
3235 }
3236 }
3237
3238 // If all of the input dimensions are the same we don't care about the
3239 // permutation.
3240 if (allTheSame) {
3241 outputShape.resize(inputRank, inputShape.getDimSize(0));
3242 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3243 return success();
3244 }
3245
3246 outputShape.resize(inputRank, ShapedType::kDynamic);
3247
3248 // Constant permutation values must be within the input rank.
3249 if (llvm::any_of(adaptor.getPerms(),
3250 [inputRank](const auto i) { return i >= inputRank; }))
3251 return failure();
3252
3253 outputShape.reserve(inputRank);
3254 for (int i = 0, s = inputRank; i < s; i++) {
3255 outputShape[i] = inputShape.getDimSize(adaptor.getPerms()[i]);
3256 }
3257
3258 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3259 return success();
3260}
3261
3262LogicalResult tosa::TransposeOp::verify() {
3263 if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
3264 /* outType = */ getOutput().getType())
3265 .failed()) {
3266 return failure();
3267 }
3268
3269 const ShapeAdaptor inputShape(getInput1().getType());
3270 const ShapeAdaptor outputShape(getOutput().getType());
3271
3272 const llvm::ArrayRef<int32_t> constantPerms = getPerms();
3273
3274 if (inputShape.hasRank() &&
3275 constantPerms.size() != static_cast<size_t>(inputShape.getRank()))
3276 return emitOpError() << "expected perms attribute to have size "
3277 << inputShape.getRank()
3278 << " (input rank) but got size "
3279 << constantPerms.size();
3280
3281 if (inputShape.hasRank() && outputShape.hasRank() &&
3282 inputShape.getRank() != outputShape.getRank())
3283 return emitOpError()
3284 << "expected input tensor rank to equal result tensor rank";
3285
3286 if (outputShape.hasRank() &&
3287 constantPerms.size() != static_cast<size_t>(outputShape.getRank()))
3288 return emitOpError() << "expected perms attribute to have size "
3289 << outputShape.getRank()
3290 << " (output rank) but got size "
3291 << constantPerms.size();
3292
3293 if (!llvm::all_of(constantPerms,
3294 [&constantPerms](int32_t s) {
3295 return s >= 0 &&
3296 static_cast<size_t>(s) < constantPerms.size();
3297 }) ||
3298 !isPermutationVector(llvm::map_to_vector(
3299 constantPerms, [](int32_t v) -> int64_t { return v; })))
3300 return emitOpError() << "expected valid permutation indices";
3301
3302 // ERROR_IF(tensor_size(shape1) != tensor_size(shape))
3303 if (inputShape.hasStaticShape() && outputShape.hasStaticShape() &&
3304 inputShape.getNumElements() != outputShape.getNumElements())
3305 return emitOpError() << "expected input1 and output to have same numbers "
3306 "of elements, got "
3307 << inputShape.getNumElements() << " and "
3308 << outputShape.getNumElements();
3309
3310 // Verify that the types of the input and output tensors are properly
3311 // permuted.
3312 if (inputShape.hasRank() && outputShape.hasRank()) {
3313 for (auto i = 0; i < outputShape.getRank(); i++) {
3314 if (inputShape.isDynamicDim(constantPerms[i]) ||
3315 outputShape.isDynamicDim(i))
3316 continue;
3317
3318 if (inputShape.getDimSize(constantPerms[i]) != outputShape.getDimSize(i))
3319 return emitOpError()
3320 << "expected output tensor dim " << i << " to match "
3321 << "input dim " << constantPerms[i] << " with value of "
3322 << inputShape.getDimSize(constantPerms[i]);
3323 }
3324 }
3325
3326 return success();
3327}
3328
3329LogicalResult TransposeOp::reifyResultShapes(
3330 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
3331
3332 const llvm::ArrayRef<int32_t> transposePerms = getPerms();
3333
3334 Value input = getInput1();
3335 auto inputType = cast<TensorType>(input.getType());
3336
3337 SmallVector<OpFoldResult> returnedDims(inputType.getRank());
3338 for (auto dim : transposePerms) {
3339 int32_t dimInInput = transposePerms[dim];
3340 if (inputType.isDynamicDim(dimInInput))
3341 returnedDims[dim] =
3342 tensor::DimOp::create(builder, getLoc(), input, dimInInput)
3343 .getResult();
3344 else
3345 returnedDims[dim] =
3346 builder.getIndexAttr(inputType.getDimSize(dimInInput));
3347 }
3348
3349 reifiedReturnShapes.emplace_back(std::move(returnedDims));
3350 return success();
3351}
3352
3353LogicalResult tosa::GatherOp::inferReturnTypeComponents(
3354 MLIRContext *context, ::std::optional<Location> location,
3355 GatherOp::Adaptor adaptor,
3356 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3357 llvm::SmallVector<int64_t> outputShape;
3358 outputShape.resize(3, ShapedType::kDynamic);
3359
3360 ShapeAdaptor valuesShape(adaptor.getValues().getType());
3361 if (valuesShape.hasRank()) {
3362 outputShape[0] = valuesShape.getDimSize(0);
3363 outputShape[2] = valuesShape.getDimSize(2);
3364 }
3365
3366 ShapeAdaptor indicesShape(adaptor.getIndices().getType());
3367 if (indicesShape.hasRank()) {
3368 if (outputShape[0] == ShapedType::kDynamic)
3369 outputShape[0] = indicesShape.getDimSize(0);
3370 if (outputShape[1] == ShapedType::kDynamic)
3371 outputShape[1] = indicesShape.getDimSize(1);
3372 }
3373
3374 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3375 return success();
3376}
3377
3378LogicalResult tosa::RowGatherOp::inferReturnTypeComponents(
3379 MLIRContext *context, ::std::optional<Location> location,
3380 RowGatherOp::Adaptor adaptor,
3381 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3382 llvm::SmallVector<int64_t> outputShape;
3383 outputShape.resize(3, ShapedType::kDynamic);
3384
3385 const ShapeAdaptor valuesShape(adaptor.getValues().getType());
3386 if (valuesShape.hasRank()) {
3387 outputShape[0] = valuesShape.getDimSize(0);
3388 outputShape[2] = valuesShape.getDimSize(2);
3389 }
3390
3391 const ShapeAdaptor indicesShape(adaptor.getIndices().getType());
3392 if (indicesShape.hasRank()) {
3393 if (outputShape[0] == ShapedType::kDynamic)
3394 outputShape[0] = indicesShape.getDimSize(0);
3395
3396 const FailureOr<int32_t> maybeRowCount =
3397 getConstantScalarIntValue<int32_t>(adaptor.getRowCount());
3398 if (succeeded(maybeRowCount)) {
3399 const int64_t indicesW = indicesShape.getDimSize(1);
3400 if (ShapedType::isStatic(indicesW))
3401 outputShape[1] = indicesW * maybeRowCount.value();
3402 }
3403 }
3404
3405 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3406 return success();
3407}
3408
3409LogicalResult tosa::RowGatherBlockScaledOp::inferReturnTypeComponents(
3410 MLIRContext *context, ::std::optional<Location> location,
3411 RowGatherBlockScaledOp::Adaptor adaptor,
3412 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3413 const auto values = adaptor.getValues();
3414 if (values.empty())
3415 return failure();
3416
3417 SmallVector<int64_t> dataShape(3, ShapedType::kDynamic);
3418 const ShapeAdaptor valuesShape(values.front().getType());
3419 if (valuesShape.hasRank()) {
3420 dataShape[0] = valuesShape.getDimSize(0);
3421 dataShape[2] = valuesShape.getDimSize(2);
3422 }
3423
3424 const ShapeAdaptor indicesShape(adaptor.getIndices().getType());
3425 if (indicesShape.hasRank()) {
3426 if (dataShape[0] == ShapedType::kDynamic)
3427 dataShape[0] = indicesShape.getDimSize(0);
3428
3429 if (auto rowCount =
3430 getConstantScalarIntValue<int32_t>(adaptor.getRowCount());
3431 succeeded(rowCount) && rowCount.value() > 0) {
3432 const int64_t indicesW = indicesShape.getDimSize(1);
3433 if (ShapedType::isStatic(indicesW))
3434 dataShape[1] = indicesW * rowCount.value();
3435 }
3436 }
3437
3438 inferredReturnShapes.push_back(ShapedTypeComponents(dataShape));
3439 if (values.size() == 1)
3440 return success();
3441
3442 SmallVector<int64_t> scaleShape = dataShape;
3443 const uint32_t blockSize =
3444 BlockSizeAttr::getBlockSizeValue(adaptor.getBlockSize());
3445 if (ShapedType::isStatic(dataShape[2]))
3446 scaleShape[2] = dataShape[2] / blockSize;
3447
3448 inferredReturnShapes.push_back(ShapedTypeComponents(scaleShape));
3449 return success();
3450}
3451
3452LogicalResult tosa::GatherOp::verify() {
3453 if (verifySameElementTypes(*this, /* inType = */ getValues().getType(),
3454 /* outType = */ getOutput().getType())
3455 .failed()) {
3456 return failure();
3457 }
3458
3459 const ShapeAdaptor valuesShape(getValues().getType());
3460 const ShapeAdaptor indicesShape(getIndices().getType());
3461 const ShapeAdaptor outputShape(getOutput().getType());
3462
3463 int64_t n = ShapedType::kDynamic;
3464 int64_t w = ShapedType::kDynamic;
3465 int64_t c = ShapedType::kDynamic;
3466
3467 if (valuesShape.hasRank()) {
3468 n = valuesShape.getDimSize(0);
3469 c = valuesShape.getDimSize(2);
3470 }
3471 if (indicesShape.hasRank()) {
3472 const int64_t indicesN = indicesShape.getDimSize(0);
3473 w = indicesShape.getDimSize(1);
3474 if (n == ShapedType::kDynamic)
3475 n = indicesN;
3476 else if (indicesN != ShapedType::kDynamic && n != indicesN)
3477 return emitOpError() << "requires indices dimension 0 to have size " << n
3478 << ", got " << indicesN;
3479 }
3480 if (outputShape.hasRank()) {
3481 const int64_t outputN = outputShape.getDimSize(0);
3482 const int64_t outputW = outputShape.getDimSize(1);
3483 const int64_t outputC = outputShape.getDimSize(2);
3484 if (n != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
3485 n != outputN)
3486 return emitOpError() << "requires output dimension 0 to have size " << n
3487 << ", got " << outputN;
3488
3489 if (w != ShapedType::kDynamic && outputW != ShapedType::kDynamic &&
3490 w != outputW)
3491 return emitOpError() << "requires output dimension 1 to have size " << w
3492 << ", got " << outputW;
3493 if (c != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
3494 c != outputC)
3495 return emitOpError() << "requires output dimension 2 to have size " << c
3496 << ", got " << outputC;
3497 }
3498 return success();
3499}
3500
3501LogicalResult tosa::RowGatherOp::verify() {
3502 if (failed(verifySameElementTypes(*this, /* inType = */ getValues().getType(),
3503 /* outType = */ getOutput().getType())))
3504 return failure();
3505
3506 const FailureOr<int32_t> maybeRowCount =
3508 if (succeeded(maybeRowCount) && maybeRowCount.value() <= 0)
3509 return emitOpError() << "requires row_count to be > 0, got "
3510 << maybeRowCount.value();
3511
3512 int64_t n = ShapedType::kDynamic;
3513 int64_t c = ShapedType::kDynamic;
3514 int64_t w = ShapedType::kDynamic;
3515
3516 const ShapeAdaptor valuesShape(getValues().getType());
3517 if (valuesShape.hasRank()) {
3518 n = valuesShape.getDimSize(0);
3519 c = valuesShape.getDimSize(2);
3520 }
3521
3522 const ShapeAdaptor indicesShape(getIndices().getType());
3523 if (indicesShape.hasRank()) {
3524 if (failed(tryUpdateDimOrFailure(*this, n, indicesShape.getDimSize(0),
3525 "indices", "batch")))
3526 return failure();
3527 w = indicesShape.getDimSize(1);
3528 }
3529
3530 const ShapeAdaptor outputShape(getOutput().getType());
3531 if (outputShape.hasRank()) {
3532 if (failed(tryUpdateDimOrFailure(*this, n, outputShape.getDimSize(0),
3533 "output", "batch")) ||
3534 failed(tryUpdateDimOrFailure(*this, c, outputShape.getDimSize(2),
3535 "output", "channels")))
3536 return failure();
3537
3538 if (succeeded(maybeRowCount) && maybeRowCount.value() > 0 &&
3539 ShapedType::isStatic(w)) {
3540 const int64_t expectedOutputRows = w * maybeRowCount.value();
3541 if (ShapedType::isStatic(outputShape.getDimSize(1)) &&
3542 outputShape.getDimSize(1) != expectedOutputRows)
3543 return emitOpError()
3544 << "requires output dimension to be equal to "
3545 "indices[1]*row_count ("
3546 << expectedOutputRows << "), got " << outputShape.getDimSize(1);
3547 }
3548 }
3549
3550 return success();
3551}
3552
3553LogicalResult tosa::RowGatherBlockScaledOp::verify() {
3554 const OperandRange values = getValues();
3555 const ResultRange output = getOutput();
3556 if (values.empty() || values.size() > 2)
3557 return emitOpError()
3558 << "expects values tensor list length to be 1 or 2, got "
3559 << values.size();
3560 if (output.size() != values.size())
3561 return emitOpError()
3562 << "expects output tensor list length to match values tensor list "
3563 "length, got "
3564 << output.size() << " results for " << values.size()
3565 << " input tensors";
3566
3567 const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
3568 if (values.size() == 1 && blockSize != 1)
3569 return emitOpError()
3570 << "requires block_size to be BLOCK_SIZE_1 when values tensor list "
3571 "length is 1";
3572 if (values.size() == 2 && blockSize == 1)
3573 return emitOpError()
3574 << "requires block_size to not be BLOCK_SIZE_1 when values tensor "
3575 "list length is 2";
3576
3577 if (failed(verifySameElementTypes(*this, values[0].getType(),
3578 output[0].getType(), "values[0]",
3579 "output[0]")))
3580 return failure();
3581 if (values.size() == 2 && failed(verifySameElementTypes(
3582 *this, values[1].getType(), output[1].getType(),
3583 "values[1]", "output[1]")))
3584 return failure();
3585
3586 if (auto rowCount = getConstantScalarIntValue<int32_t>(getRowCount());
3587 succeeded(rowCount) && rowCount.value() <= 0)
3588 return emitOpError() << "requires row_count to be > 0, got "
3589 << rowCount.value();
3590
3591 int64_t n = ShapedType::kDynamic;
3592 int64_t k = ShapedType::kDynamic;
3593 int64_t c = ShapedType::kDynamic;
3594 int64_t w = ShapedType::kDynamic;
3595 int64_t multiplesOfC = ShapedType::kDynamic;
3596
3597 const ShapeAdaptor valuesDataShape(values[0].getType());
3598 if (valuesDataShape.hasRank()) {
3599 n = valuesDataShape.getDimSize(0);
3600 k = valuesDataShape.getDimSize(1);
3601 c = valuesDataShape.getDimSize(2);
3602 }
3603
3604 if (ShapedType::isStatic(c) && c % blockSize != 0)
3605 return emitOpError() << "expects channels of values[0] (" << c
3606 << ") to be divisible by block_size (" << blockSize
3607 << ")";
3608
3609 const ShapeAdaptor indicesShape(getIndices().getType());
3610 if (indicesShape.hasRank()) {
3611 if (failed(tryUpdateDimOrFailure(*this, n, indicesShape.getDimSize(0),
3612 "indices", "batch")))
3613 return failure();
3614 w = indicesShape.getDimSize(1);
3615 }
3616
3617 const ShapeAdaptor outputDataShape(output[0].getType());
3618 if (outputDataShape.hasRank()) {
3619 if (failed(tryUpdateDimOrFailure(*this, n, outputDataShape.getDimSize(0),
3620 "output[0]", "batch")) ||
3621 failed(tryUpdateDimOrFailure(*this, c, outputDataShape.getDimSize(2),
3622 "output[0]", "channels")))
3623 return failure();
3624
3625 if (auto rowCount = getConstantScalarIntValue<int32_t>(getRowCount());
3626 succeeded(rowCount) && rowCount.value() > 0 &&
3627 ShapedType::isStatic(w)) {
3628 const int64_t expectedOutputRows = w * rowCount.value();
3629 if (ShapedType::isStatic(outputDataShape.getDimSize(1)) &&
3630 outputDataShape.getDimSize(1) != expectedOutputRows)
3631 return emitOpError() << "requires output[0] dimension 1 to have size "
3632 << expectedOutputRows << ", got "
3633 << outputDataShape.getDimSize(1);
3634 }
3635 }
3636
3637 if (values.size() == 2) {
3638 const ShapeAdaptor valuesScaleShape(values[1].getType());
3639 if (valuesScaleShape.hasRank()) {
3640 if (failed(tryUpdateDimOrFailure(*this, n, valuesScaleShape.getDimSize(0),
3641 "values[1]", "batch")) ||
3642 failed(tryUpdateDimOrFailure(*this, k, valuesScaleShape.getDimSize(1),
3643 "values[1]", "rows")))
3644 return failure();
3645 multiplesOfC = valuesScaleShape.getDimSize(2);
3646 }
3647
3648 const ShapeAdaptor outputScaleShape(output[1].getType());
3649 if (outputScaleShape.hasRank()) {
3650 if (failed(tryUpdateDimOrFailure(*this, n, outputScaleShape.getDimSize(0),
3651 "output[1]", "batch")))
3652 return failure();
3653
3654 if (auto rowCount = getConstantScalarIntValue<int32_t>(getRowCount());
3655 succeeded(rowCount) && rowCount.value() > 0 &&
3656 ShapedType::isStatic(w)) {
3657 const int64_t expectedOutputRows = w * rowCount.value();
3658 if (ShapedType::isStatic(outputScaleShape.getDimSize(1)) &&
3659 outputScaleShape.getDimSize(1) != expectedOutputRows)
3660 return emitOpError() << "requires output[1] dimension 1 to have size "
3661 << expectedOutputRows << ", got "
3662 << outputScaleShape.getDimSize(1);
3663 }
3664
3665 if (ShapedType::isDynamic(multiplesOfC))
3666 multiplesOfC = outputScaleShape.getDimSize(2);
3667 else if (ShapedType::isStatic(outputScaleShape.getDimSize(2)) &&
3668 multiplesOfC != outputScaleShape.getDimSize(2))
3669 return emitOpError()
3670 << "expected channels of output[1] to match size "
3671 << multiplesOfC << ", got " << outputScaleShape.getDimSize(2);
3672 }
3673
3674 if (ShapedType::isStatic(c) && ShapedType::isStatic(multiplesOfC) &&
3675 multiplesOfC != c / blockSize)
3676 return emitOpError()
3677 << "expects channels of scale tensors to equal C/block_size (" << c
3678 << "/" << blockSize << "), got " << multiplesOfC;
3679 }
3680
3681 return success();
3682}
3683
3684LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
3685 MLIRContext *context, ::std::optional<Location> location,
3686 ResizeOp::Adaptor adaptor,
3687 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3688 llvm::SmallVector<int64_t, 4> outputShape;
3689 outputShape.resize(4, ShapedType::kDynamic);
3690
3691 ShapeAdaptor inputShape(adaptor.getInput().getType());
3692 if (!inputShape.hasRank())
3693 return failure();
3694
3695 outputShape[0] = inputShape.getDimSize(0);
3696 outputShape[3] = inputShape.getDimSize(3);
3697 int64_t inputHeight = inputShape.getDimSize(1);
3698 int64_t inputWidth = inputShape.getDimSize(2);
3699
3700 if ((inputHeight == ShapedType::kDynamic) ||
3701 (inputWidth == ShapedType::kDynamic))
3702 return failure();
3703
3704 SmallVector<int64_t> scaleInt, offsetInt, borderInt;
3705 if (!tosa::getConstShapeValues(adaptor.getScale().getDefiningOp(),
3706 scaleInt) ||
3707 !tosa::getConstShapeValues(adaptor.getOffset().getDefiningOp(),
3708 offsetInt) ||
3709 !tosa::getConstShapeValues(adaptor.getBorder().getDefiningOp(),
3710 borderInt)) {
3711 return failure();
3712 }
3713
3714 // Compute the output shape based on attributes: scale, offset, and border.
3715 const int64_t outputHeight =
3716 (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) /
3717 scaleInt[1]) +
3718 1;
3719
3720 const int64_t outputWidth =
3721 (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) /
3722 scaleInt[3]) +
3723 1;
3724
3725 if (outputHeight < 0 || outputWidth < 0) {
3726 return emitOptionalError(
3727 location,
3728 "calculated output height and width must be non-negative, "
3729 "got height = ",
3730 outputHeight, ", width = ", outputWidth);
3731 }
3732
3733 outputShape[1] = outputHeight;
3734 outputShape[2] = outputWidth;
3735 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3736 return success();
3737}
3738
3739LogicalResult tosa::ResizeOp::verify() {
3740 const Value input = getInput();
3741 const Value output = getOutput();
3742 const RankedTensorType inputType =
3743 llvm::dyn_cast<RankedTensorType>(input.getType());
3744 const RankedTensorType outputType =
3745 llvm::dyn_cast<RankedTensorType>(output.getType());
3746
3747 SmallVector<int64_t> scaleValues;
3748 SmallVector<int64_t> offsetValues;
3749 SmallVector<int64_t> borderValues;
3750 if (!tosa::getConstShapeValues(getScale().getDefiningOp(), scaleValues) ||
3751 !tosa::getConstShapeValues(getOffset().getDefiningOp(), offsetValues) ||
3752 !tosa::getConstShapeValues(getBorder().getDefiningOp(), borderValues)) {
3753 // Skip following checks if shape is not constant
3754 return success();
3755 }
3756
3757 if (llvm::any_of(scaleValues, [](int64_t s) { return s <= 0; }))
3758 return emitOpError("expect all scale values to be > 0, got ")
3759 << scaleValues;
3760
3761 const int64_t scaleYN = scaleValues[0];
3762 const int64_t scaleYD = scaleValues[1];
3763 const int64_t scaleXN = scaleValues[2];
3764 const int64_t scaleXD = scaleValues[3];
3765
3766 const int64_t offsetY = offsetValues[0];
3767 const int64_t offsetX = offsetValues[1];
3768
3769 const int64_t borderY = borderValues[0];
3770 const int64_t borderX = borderValues[1];
3771
3772 if (!inputType)
3773 return success();
3774 if (!outputType)
3775 return success();
3776
3777 const int64_t oh = outputType.getDimSize(1);
3778 const int64_t ow = outputType.getDimSize(2);
3779 const int64_t ih = inputType.getDimSize(1);
3780 const int64_t iw = inputType.getDimSize(2);
3781
3782 // Don't check with input height that could be broadcast (ih != 1)
3783 // since Linalg, a consumer of TOSA, expects broadcasting support
3784 // in resize to be available. Taking the cautious approach for now,
3785 // we can consider removing support for broadcasting later.
3786 if (ih != ShapedType::kDynamic && ih != 1) {
3787 const std::optional<int64_t> calculatedOutHeightMinusOne =
3788 idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
3789 if (!calculatedOutHeightMinusOne.has_value())
3790 return emitOpError("expected (input_height - 1) * scale_y_n - offset_y + "
3791 "border_y ")
3792 << "to be wholly divisible by scale_y_d, got ((" << ih
3793 << " - 1) * " << scaleYN << " - " << offsetY << " + " << borderY
3794 << ") / " << scaleYD;
3795 const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
3796 if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
3797 return emitOpError("calculated output height did not match expected: ")
3798 << "calculated=" << calculatedOutHeight << ", expected=" << oh;
3799 }
3800
3801 // Don't check with input width that could be broadcast (iw != 1)
3802 // since Linalg, a consumer of TOSA, expects broadcasting support
3803 // in resize to be available. Taking the cautious approach for now,
3804 // we can consider removing support for broadcasting later.
3805 if (iw != ShapedType::kDynamic && iw != 1) {
3806 const int64_t scaledInWidth = (iw - 1) * scaleXN - offsetX + borderX;
3807 const std::optional<int64_t> calculatedOutWidthMinusOne =
3808 idivCheck(scaledInWidth, scaleXD);
3809 if (!calculatedOutWidthMinusOne.has_value())
3810 return emitOpError("expected (input_width - 1) * scale_x_n - offset_x + "
3811 "border_x ")
3812 << "to be wholly divisible by scale_x_d, got ((" << iw
3813 << " - 1) * " << scaleXN << " - " << offsetX << " + " << borderX
3814 << ") / " << scaleXD;
3815 const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
3816 if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
3817 return emitOpError("calculated output width did not match expected: ")
3818 << "calculated=" << calculatedOutWidth << ", expected=" << ow;
3819 }
3820
3821 return success();
3822}
3823
3824LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
3825 MLIRContext *context, ::std::optional<Location> location,
3826 ScatterOp::Adaptor adaptor,
3827 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3828 llvm::SmallVector<int64_t> outputShape;
3829 outputShape.resize(3, ShapedType::kDynamic);
3830
3831 ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());
3832 if (valuesInShape.hasRank()) {
3833 outputShape[0] = valuesInShape.getDimSize(0);
3834 outputShape[1] = valuesInShape.getDimSize(1);
3835 outputShape[2] = valuesInShape.getDimSize(2);
3836 }
3837
3838 ShapeAdaptor indicesShape(adaptor.getIndices().getType());
3839 if (indicesShape.hasRank()) {
3840 if (outputShape[0] == ShapedType::kDynamic)
3841 outputShape[0] = indicesShape.getDimSize(0);
3842 }
3843
3844 ShapeAdaptor inputShape(adaptor.getInput().getType());
3845 if (inputShape.hasRank()) {
3846 if (outputShape[0] == ShapedType::kDynamic)
3847 outputShape[0] = inputShape.getDimSize(0);
3848 if (outputShape[2] == ShapedType::kDynamic)
3849 outputShape[2] = inputShape.getDimSize(2);
3850 }
3851
3852 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3853 return success();
3854}
3855
3856LogicalResult tosa::ScatterOp::verify() {
3857 if (verifySameElementTypes(*this, /* inType = */ getValuesIn().getType(),
3858 /* outType = */ getValuesOut().getType())
3859 .failed() ||
3860 verifySameElementTypes(*this, /* inType = */ getInput().getType(),
3861 /* outType = */ getValuesOut().getType())
3862 .failed()) {
3863 return failure();
3864 }
3865
3866 const ShapeAdaptor valuesInShape(getValuesIn().getType());
3867 const ShapeAdaptor indicesShape(getIndices().getType());
3868 const ShapeAdaptor inputShape(getInput().getType());
3869 const ShapeAdaptor outputShape(getValuesOut().getType());
3870
3871 int64_t n = ShapedType::kDynamic;
3872 int64_t k = ShapedType::kDynamic;
3873 int64_t w = ShapedType::kDynamic;
3874 int64_t c = ShapedType::kDynamic;
3875 if (valuesInShape.hasRank()) {
3876 n = valuesInShape.getDimSize(0);
3877 k = valuesInShape.getDimSize(1);
3878 c = valuesInShape.getDimSize(2);
3879 }
3880 if (indicesShape.hasRank()) {
3881 const int64_t indicesN = indicesShape.getDimSize(0);
3882 w = indicesShape.getDimSize(1);
3883 if (n == ShapedType::kDynamic)
3884 n = indicesN;
3885 else if (indicesN != ShapedType::kDynamic && n != indicesN)
3886 return emitOpError() << "requires indices dimension 0 to have size " << n
3887 << ", got " << indicesN;
3888 }
3889 if (inputShape.hasRank()) {
3890 const int64_t inputN = inputShape.getDimSize(0);
3891 const int64_t inputW = inputShape.getDimSize(1);
3892 const int64_t inputC = inputShape.getDimSize(2);
3893 if (n == ShapedType::kDynamic)
3894 n = inputN;
3895 else if (inputN != ShapedType::kDynamic && n != inputN)
3896 return emitOpError() << "requires input dimension 0 to have size " << n
3897 << ", got " << inputN;
3898 if (w == ShapedType::kDynamic)
3899 w = inputW;
3900 else if (inputW != ShapedType::kDynamic && w != inputW)
3901 return emitOpError() << "requires input dimension 1 to have size " << w
3902 << ", got " << inputW;
3903
3904 if (c == ShapedType::kDynamic)
3905 c = inputC;
3906 else if (inputC != ShapedType::kDynamic && c != inputC)
3907 return emitOpError() << "requires input dimension 2 to have size " << c
3908 << ", got " << inputC;
3909 }
3910 if (outputShape.hasRank()) {
3911 const int64_t outputN = outputShape.getDimSize(0);
3912 const int64_t outputK = outputShape.getDimSize(1);
3913 const int64_t outputC = outputShape.getDimSize(2);
3914 if (n != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
3915 n != outputN)
3916 return emitOpError() << "requires values_out dimension 0 to have size "
3917 << n << ", got " << outputN;
3918 if (k == ShapedType::kDynamic)
3919 k = outputK;
3920 else if (outputK != ShapedType::kDynamic && k != outputK)
3921 return emitOpError() << "requires values_out dimension 1 to have size "
3922 << k << ", got " << outputK;
3923 if (c != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
3924 c != outputC)
3925 return emitOpError() << "requires values_out dimension 2 to have size "
3926 << c << ", got " << outputC;
3927 }
3928 if (k != ShapedType::kDynamic && w != ShapedType::kDynamic && !(k >= w))
3929 return emitOpError() << "requires dimensions K >= W, got K=" << k
3930 << " and W=" << w;
3931
3932 return success();
3933}
3934
3935static LogicalResult ReduceInferReturnTypes(
3936 ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
3937 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3938 int64_t axisVal = axis.getValue().getSExtValue();
3939 if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) {
3940 inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
3941 return success();
3942 }
3943
3944 SmallVector<int64_t> outputShape;
3945 operandShape.getDims(outputShape);
3946 outputShape[axisVal] = 1;
3947 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
3948 return success();
3949}
3950
// Defines OP::isCompatibleReturnTypes for a single-result op. The two result
// type lists must each hold exactly one type, the element types must match
// (via getElementTypeOrSelf), and the shapes must pass verifyCompatibleShape.
#define COMPATIBLE_RETURN_TYPES(OP)                                            \
  bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) {                 \
    if (l.size() != r.size() || l.size() != 1)                                 \
      return false;                                                            \
    if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0]))              \
      return false;                                                            \
    return succeeded(verifyCompatibleShape(l[0], r[0]));                       \
  }

// Defines OP::inferReturnTypeComponents for a tosa.reduce_* op. It passes the
// input's shape, the input's element type and the `axis` property to
// ReduceInferReturnTypes. It also defines OP::isCompatibleReturnTypes via
// COMPATIBLE_RETURN_TYPES.
#define REDUCE_SHAPE_INFER(OP)                                                 \
  LogicalResult OP::inferReturnTypeComponents(                                 \
      MLIRContext *context, ::std::optional<Location> location,                \
      OP::Adaptor adaptor,                                                     \
      SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {           \
    Type inputType =                                                           \
        llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \
    ShapeAdaptor inputShape(adaptor.getInput().getType());                     \
    const Properties &prop = adaptor.getProperties();                          \
    return ReduceInferReturnTypes(inputShape, inputType, prop.axis,            \
                                  inferredReturnShapes);                       \
  }                                                                            \
  COMPATIBLE_RETURN_TYPES(OP)

REDUCE_SHAPE_INFER(tosa::ReduceAllOp)
REDUCE_SHAPE_INFER(tosa::ReduceAnyOp)
REDUCE_SHAPE_INFER(tosa::ReduceMaxOp)
REDUCE_SHAPE_INFER(tosa::ReduceMinOp)
REDUCE_SHAPE_INFER(tosa::ReduceProductOp)
REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
#undef REDUCE_SHAPE_INFER
// ConcatOp reuses only the return-type compatibility hook from this section.
COMPATIBLE_RETURN_TYPES(tosa::ConcatOp)
#undef COMPATIBLE_RETURN_TYPES
3983
3984template <typename T>
3985static LogicalResult verifyReduceOp(T op) {
3986 // All TOSA reduce Ops have input, output and axis.
3987 TensorType inputType = op.getInput().getType();
3988 TensorType outputType = op.getOutput().getType();
3989 int32_t reduceAxis = op.getAxis();
3990
3991 if (reduceAxis < 0) {
3992 op.emitOpError("reduce axis must not be negative");
3993 return failure();
3994 }
3995 if (inputType.hasRank()) {
3996 int64_t inputRank = inputType.getRank();
3997 // We allow for a special case where the input/output shape has rank 0 and
3998 // axis is also 0.
3999 if (reduceAxis >= inputRank && (reduceAxis != 0 || inputRank != 0)) {
4000 op.emitOpError("expect input tensor rank (")
4001 << inputRank << ") to be larger than reduce axis (" << reduceAxis
4002 << ")";
4003 return failure();
4004 }
4005 }
4006 if (outputType.hasRank()) {
4007 int64_t outputRank = outputType.getRank();
4008 if (inputType.hasRank() && outputRank != inputType.getRank()) {
4009 op.emitOpError(
4010 "expect output tensor rank to be equal to input tensor rank");
4011 return failure();
4012 }
4013 if (reduceAxis >= outputRank && (reduceAxis != 0 || outputRank != 0)) {
4014 op.emitOpError("expect output tensor rank (")
4015 << outputRank << ") to be larger than reduce axis (" << reduceAxis
4016 << ")";
4017 return failure();
4018 }
4019 // We can only verify the reduced dimension size to be 1 if this is not
4020 // the special case of output rank == 0.
4021 if (outputRank != 0) {
4022 auto outputShape = outputType.getShape();
4023 if (!outputType.isDynamicDim(reduceAxis) &&
4024 outputShape[reduceAxis] != 1) {
4025 op.emitOpError("expect reduced dimension size to be 1, got ")
4026 << outputShape[reduceAxis];
4027 return failure();
4028 }
4029 }
4030 }
4031 return success();
4032}
4033
// Every tosa.reduce_* op shares the same verification logic (verifyReduceOp).
LogicalResult tosa::ReduceAllOp::verify() { return verifyReduceOp(*this); }
LogicalResult tosa::ReduceAnyOp::verify() { return verifyReduceOp(*this); }
LogicalResult tosa::ReduceMaxOp::verify() { return verifyReduceOp(*this); }
LogicalResult tosa::ReduceMinOp::verify() { return verifyReduceOp(*this); }
LogicalResult tosa::ReduceProductOp::verify() { return verifyReduceOp(*this); }
LogicalResult tosa::ReduceSumOp::verify() { return verifyReduceOp(*this); }
4040
4041static LogicalResult NAryInferReturnTypes(
4042 const ValueShapeRange &operands,
4043 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4045 if (resolveBroadcastShape(operands, outShape).failed()) {
4046 inferredReturnShapes.push_back(ShapedTypeComponents());
4047 } else {
4048 inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
4049 }
4050 return success();
4051}
4052
4053#define NARY_SHAPE_INFER(OP) \
4054 LogicalResult OP::inferReturnTypeComponents( \
4055 MLIRContext *context, ::std::optional<Location> location, \
4056 ValueShapeRange operands, DictionaryAttr attributes, \
4057 PropertyRef properties, RegionRange regions, \
4058 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
4059 return NAryInferReturnTypes(operands, inferredReturnShapes); \
4060 }
4061
4062NARY_SHAPE_INFER(tosa::AbsOp)
4063NARY_SHAPE_INFER(tosa::AddOp)
4064NARY_SHAPE_INFER(tosa::ArithmeticRightShiftOp)
4065NARY_SHAPE_INFER(tosa::BitwiseAndOp)
4066NARY_SHAPE_INFER(tosa::BitwiseOrOp)
4067NARY_SHAPE_INFER(tosa::BitwiseXorOp)
4068NARY_SHAPE_INFER(tosa::BitwiseNotOp)
4069NARY_SHAPE_INFER(tosa::CastOp)
4070NARY_SHAPE_INFER(tosa::CeilOp)
4071NARY_SHAPE_INFER(tosa::ClampOp)
4072NARY_SHAPE_INFER(tosa::ClzOp)
4073NARY_SHAPE_INFER(tosa::CosOp)
4074NARY_SHAPE_INFER(tosa::ExpOp)
4075NARY_SHAPE_INFER(tosa::FloorOp)
4076NARY_SHAPE_INFER(tosa::GreaterEqualOp)
4077NARY_SHAPE_INFER(tosa::GreaterOp)
4078NARY_SHAPE_INFER(tosa::IdentityOp)
4079NARY_SHAPE_INFER(tosa::IntDivOp)
4080NARY_SHAPE_INFER(tosa::LogOp)
4081NARY_SHAPE_INFER(tosa::LogicalAndOp)
4082NARY_SHAPE_INFER(tosa::LogicalLeftShiftOp)
4083NARY_SHAPE_INFER(tosa::LogicalNotOp)
4084NARY_SHAPE_INFER(tosa::LogicalOrOp)
4085NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
4086NARY_SHAPE_INFER(tosa::LogicalXorOp)
4087NARY_SHAPE_INFER(tosa::MaximumOp)
4088NARY_SHAPE_INFER(tosa::MinimumOp)
4089NARY_SHAPE_INFER(tosa::PowOp)
4090NARY_SHAPE_INFER(tosa::ReciprocalOp)
4091NARY_SHAPE_INFER(tosa::ReverseOp)
4092NARY_SHAPE_INFER(tosa::RsqrtOp)
4093NARY_SHAPE_INFER(tosa::SinOp)
4094NARY_SHAPE_INFER(tosa::SelectOp)
4095NARY_SHAPE_INFER(tosa::SubOp)
4096NARY_SHAPE_INFER(tosa::TanhOp)
4097NARY_SHAPE_INFER(tosa::ErfOp)
4098NARY_SHAPE_INFER(tosa::SigmoidOp)
4099#undef PRED_SHAPE_INFER
4100
4101LogicalResult tosa::NegateOp::inferReturnTypeComponents(
4102 MLIRContext *context, ::std::optional<Location> location,
4103 NegateOp::Adaptor adaptor,
4104 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4105 ShapeAdaptor inputShape(adaptor.getInput1().getType());
4106 inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
4107 return success();
4108}
4109
4110LogicalResult tosa::NegateOp::verify() {
4111 // Verify same element type
4112 const Type input1Type = getInput1().getType();
4113 const Type outputType = getOutput().getType();
4114 if (verifySameElementTypes(*this, input1Type, outputType).failed())
4115 return failure();
4116
4117 // Verify same shape
4118 const SmallVector<Type, 2> types = {input1Type, outputType};
4119 if (failed(verifyCompatibleShapes(types)))
4120 return emitOpError() << "requires the same shape for input1 and output";
4121
4122 const Type input1EType = getStorageElementTypeOrSelf(getInput1().getType());
4123 const Type input1ZpEType =
4124 getStorageElementTypeOrSelf(getInput1Zp().getType());
4125 if (input1EType != input1ZpEType) {
4126 return emitOpError("expect both input1 and its zero point are the same "
4127 "element type, got ")
4128 << input1EType << " and " << input1ZpEType;
4129 }
4130 const Type outputEType = getStorageElementTypeOrSelf(getOutput().getType());
4131 const Type outputZpEType =
4132 getStorageElementTypeOrSelf(getOutputZp().getType());
4133 if (outputEType != outputZpEType) {
4134 return emitOpError("expect both output and its zero point are the same "
4135 "element type, got ")
4136 << outputEType << " and " << outputZpEType;
4137 }
4138
4139 FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
4140 if (succeeded(maybeIZp) && verifyInput1ZeroPoint(*maybeIZp).failed())
4141 return failure();
4142
4143 FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
4144 if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
4145 return failure();
4146
4147 return success();
4148}
4149
4150static LogicalResult poolingInferReturnTypes(
4151 ShapeAdaptor inputShape, ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride,
4153 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4154 llvm::SmallVector<int64_t> outputShape;
4155 outputShape.resize(4, ShapedType::kDynamic);
4156
4157 // We only know the rank if the input type is unranked.
4158 if (!inputShape.hasRank()) {
4159 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
4160 return success();
4161 }
4162
4163 // Batch and number of channels are identical for pooling layer.
4164 outputShape[0] = inputShape.getDimSize(0);
4165 outputShape[3] = inputShape.getDimSize(3);
4166
4167 int64_t height = inputShape.getDimSize(1);
4168 int64_t width = inputShape.getDimSize(2);
4169
4170 if (ShapedType::isStatic(height)) {
4171 int64_t padded = height + pad[0] + pad[1] - kernel[0];
4172 outputShape[1] = padded / stride[0] + 1;
4173 }
4174
4175 if (ShapedType::isStatic(width)) {
4176 int64_t padded = width + pad[2] + pad[3] - kernel[1];
4177 outputShape[2] = padded / stride[1] + 1;
4178 }
4179
4180 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
4181 return success();
4182}
4183
4184template <typename AdaptorT>
4186
4188protected:
4189 static void updateIfDynamic(int64_t &current, int64_t candidate) {
4190 if (ShapedType::isDynamic(current))
4191 current = candidate;
4192 }
4193};
4194
4195template <>
4196class ConvInferShapeAdaptor<Conv2DOp::Adaptor>
4197 : public ConvInferShapeAdaptorBase {
4198public:
4199 explicit ConvInferShapeAdaptor(Conv2DOp::Adaptor adaptor)
4200 : adaptor(adaptor) {}
4201
4203 SmallVectorImpl<int64_t> &inputSpatial) {
4204 const ShapeAdaptor inputShape(adaptor.getInput().getType());
4205 if (!inputShape.hasRank())
4206 return;
4207
4208 const int64_t outputBatch = inputShape.getDimSize(0);
4209 const int64_t inputHeight = inputShape.getDimSize(1);
4210 const int64_t inputWidth = inputShape.getDimSize(2);
4211
4212 outputShape[0] = outputBatch;
4213 inputSpatial[0] = inputHeight;
4214 inputSpatial[1] = inputWidth;
4215 }
4216
4218 SmallVectorImpl<int64_t> &weightSpatial) {
4219 const ShapeAdaptor weightShape(adaptor.getWeight().getType());
4220 if (!weightShape.hasRank())
4221 return;
4222
4223 const int64_t outputChannels = weightShape.getDimSize(0);
4224 const int64_t kernelHeight = weightShape.getDimSize(1);
4225 const int64_t kernelWidth = weightShape.getDimSize(2);
4226
4227 outputShape[3] = outputChannels;
4228 weightSpatial[0] = kernelHeight;
4229 weightSpatial[1] = kernelWidth;
4230 }
4231
4232 int64_t getNumSpatialDims() const { return 2; }
4233 int64_t getOutputRank() const { return 4; }
4234
4236 SmallVector<int64_t> &strideValues,
4237 SmallVector<int64_t> &dilationValues) {
4238 padValues.assign(adaptor.getPad().begin(), adaptor.getPad().end());
4239 strideValues.assign(adaptor.getStride().begin(), adaptor.getStride().end());
4240 dilationValues.assign(adaptor.getDilation().begin(),
4241 adaptor.getDilation().end());
4242 return success();
4243 }
4244
4245private:
4246 Conv2DOp::Adaptor adaptor;
4247};
4248
4249template <>
4250class ConvInferShapeAdaptor<Conv2DBlockScaledOp::Adaptor>
4251 : public ConvInferShapeAdaptorBase {
4252public:
4253 explicit ConvInferShapeAdaptor(Conv2DBlockScaledOp::Adaptor adaptor)
4254 : adaptor(adaptor) {}
4255
4257 SmallVectorImpl<int64_t> &inputSpatial) {
4258 const ShapeAdaptor inputDataShape(adaptor.getInputData().getType());
4259 if (inputDataShape.hasRank()) {
4260 const int64_t outputBatch = inputDataShape.getDimSize(0);
4261 const int64_t inputHeight = inputDataShape.getDimSize(1);
4262 const int64_t inputWidth = inputDataShape.getDimSize(2);
4263
4264 outputShape[0] = outputBatch;
4265 inputSpatial[0] = inputHeight;
4266 inputSpatial[1] = inputWidth;
4267 }
4268
4269 const ShapeAdaptor inputScaleShape(adaptor.getInputScale().getType());
4270 if (!inputScaleShape.hasRank())
4271 return;
4272
4273 const int64_t scaleBatch = inputScaleShape.getDimSize(0);
4274 const int64_t scaleHeight = inputScaleShape.getDimSize(1);
4275 const int64_t scaleWidth = inputScaleShape.getDimSize(2);
4276
4277 updateIfDynamic(outputShape[0], scaleBatch);
4278 updateIfDynamic(inputSpatial[0], scaleHeight);
4279 updateIfDynamic(inputSpatial[1], scaleWidth);
4280 }
4281
4283 SmallVectorImpl<int64_t> &weightSpatial) {
4284 const ShapeAdaptor weightDataShape(adaptor.getWeightData().getType());
4285 if (weightDataShape.hasRank()) {
4286 const int64_t outputChannels = weightDataShape.getDimSize(0);
4287 const int64_t kernelHeight = weightDataShape.getDimSize(1);
4288 const int64_t kernelWidth = weightDataShape.getDimSize(2);
4289
4290 outputShape[3] = outputChannels;
4291 weightSpatial[0] = kernelHeight;
4292 weightSpatial[1] = kernelWidth;
4293 }
4294
4295 const ShapeAdaptor weightScaleShape(adaptor.getWeightScale().getType());
4296 if (!weightScaleShape.hasRank())
4297 return;
4298
4299 const int64_t scaleOutputChannels = weightScaleShape.getDimSize(0);
4300 const int64_t scaleKernelHeight = weightScaleShape.getDimSize(1);
4301 const int64_t scaleKernelWidth = weightScaleShape.getDimSize(2);
4302
4303 updateIfDynamic(outputShape[3], scaleOutputChannels);
4304 updateIfDynamic(weightSpatial[0], scaleKernelHeight);
4305 updateIfDynamic(weightSpatial[1], scaleKernelWidth);
4306 }
4307
4308 int64_t getNumSpatialDims() const { return 2; }
4309 int64_t getOutputRank() const { return 4; }
4310
4312 SmallVector<int64_t> &strideValues,
4313 SmallVector<int64_t> &dilationValues) {
4314 if (!tosa::getConstShapeValues(adaptor.getPad().getDefiningOp(),
4315 padValues) ||
4316 !tosa::getConstShapeValues(adaptor.getStride().getDefiningOp(),
4317 strideValues) ||
4318 !tosa::getConstShapeValues(adaptor.getDilation().getDefiningOp(),
4319 dilationValues))
4320 return failure();
4321 return success();
4322 }
4323
4324private:
4325 Conv2DBlockScaledOp::Adaptor adaptor;
4326};
4327
4328template <>
4329class ConvInferShapeAdaptor<Conv3DOp::Adaptor>
4330 : public ConvInferShapeAdaptorBase {
4331public:
4332 explicit ConvInferShapeAdaptor(Conv3DOp::Adaptor adaptor)
4333 : adaptor(adaptor) {}
4334
4336 SmallVectorImpl<int64_t> &inputSpatial) {
4337 const ShapeAdaptor inputShape(adaptor.getInput().getType());
4338 if (!inputShape.hasRank())
4339 return;
4340
4341 const int64_t outputBatch = inputShape.getDimSize(0);
4342 const int64_t inputDepth = inputShape.getDimSize(1);
4343 const int64_t inputHeight = inputShape.getDimSize(2);
4344 const int64_t inputWidth = inputShape.getDimSize(3);
4345
4346 outputShape[0] = outputBatch;
4347 inputSpatial[0] = inputDepth;
4348 inputSpatial[1] = inputHeight;
4349 inputSpatial[2] = inputWidth;
4350 }
4351
4353 SmallVectorImpl<int64_t> &weightSpatial) {
4354 const ShapeAdaptor weightShape(adaptor.getWeight().getType());
4355 if (!weightShape.hasRank())
4356 return;
4357
4358 const int64_t outputChannels = weightShape.getDimSize(0);
4359 const int64_t kernelDepth = weightShape.getDimSize(1);
4360 const int64_t kernelHeight = weightShape.getDimSize(2);
4361 const int64_t kernelWidth = weightShape.getDimSize(3);
4362
4363 outputShape[4] = outputChannels;
4364 weightSpatial[0] = kernelDepth;
4365 weightSpatial[1] = kernelHeight;
4366 weightSpatial[2] = kernelWidth;
4367 }
4368
4369 int64_t getNumSpatialDims() const { return 3; }
4370 int64_t getOutputRank() const { return 5; }
4371
4373 SmallVector<int64_t> &strideValues,
4374 SmallVector<int64_t> &dilationValues) {
4375 padValues.assign(adaptor.getPad().begin(), adaptor.getPad().end());
4376 strideValues.assign(adaptor.getStride().begin(), adaptor.getStride().end());
4377 dilationValues.assign(adaptor.getDilation().begin(),
4378 adaptor.getDilation().end());
4379 return success();
4380 }
4381
4382private:
4383 Conv3DOp::Adaptor adaptor;
4384};
4385
4386template <typename AdaptorT>
4388 AdaptorT adaptor,
4389 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4390 ConvInferShapeAdaptor<AdaptorT> convShapeAdaptor(adaptor);
4391 llvm::SmallVector<int64_t> outputShape(convShapeAdaptor.getOutputRank(),
4392 ShapedType::kDynamic);
4393 llvm::SmallVector<int64_t> inputSpatial(convShapeAdaptor.getNumSpatialDims(),
4394 ShapedType::kDynamic);
4395 llvm::SmallVector<int64_t> weightSpatial(convShapeAdaptor.getNumSpatialDims(),
4396 ShapedType::kDynamic);
4397
4398 convShapeAdaptor.inferInputShape(outputShape, inputSpatial);
4399 convShapeAdaptor.inferWeightShape(outputShape, weightSpatial);
4400
4401 const ShapeAdaptor biasShape = adaptor.getBias().getType();
4402 if (biasShape.hasRank()) {
4403 const int64_t biasSize = biasShape.getDimSize(0);
4404 if (biasSize != 1) {
4405 const size_t outputChannelDim = convShapeAdaptor.getOutputRank() - 1;
4406 outputShape[outputChannelDim] =
4407 ShapedType::isDynamic(outputShape[outputChannelDim])
4408 ? biasSize
4409 : outputShape[outputChannelDim];
4410 }
4411 }
4412
4413 SmallVector<int64_t> padValues;
4414 SmallVector<int64_t> strideValues;
4415 SmallVector<int64_t> dilationValues;
4416 if (failed(convShapeAdaptor.getSpatialParameters(padValues, strideValues,
4417 dilationValues))) {
4418 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
4419 return success();
4420 }
4421
4422 for (int64_t dim = 0; dim < convShapeAdaptor.getNumSpatialDims(); ++dim) {
4423 if (!ShapedType::isStatic(inputSpatial[dim]) ||
4424 !ShapedType::isStatic(weightSpatial[dim]))
4425 continue;
4426 const int64_t inputSize =
4427 inputSpatial[dim] + padValues[2 * dim] + padValues[2 * dim + 1];
4428 const int64_t filterSize =
4429 (weightSpatial[dim] - 1) * dilationValues[dim] + 1;
4430 const int64_t unstridedResult = inputSize - filterSize + 1;
4431 outputShape[dim + 1] = (unstridedResult - 1) / strideValues[dim] + 1;
4432 }
4433
4434 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
4435 return success();
4436}
4437
4438LogicalResult Conv2DOp::inferReturnTypeComponents(
4439 MLIRContext *context, ::std::optional<Location> location,
4440 Conv2DOp::Adaptor adaptor,
4441 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4442 return inferConvReturnTypeComponents(adaptor, inferredReturnShapes);
4443}
4444
4445LogicalResult Conv2DOp::verify() {
4446 if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
4447 verifyConvOpErrorIf(*this).failed())
4448 return failure();
4449 return success();
4450}
4451
4452LogicalResult Conv2DBlockScaledOp::inferReturnTypeComponents(
4453 MLIRContext *context, ::std::optional<Location> location,
4454 Conv2DBlockScaledOp::Adaptor adaptor,
4455 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4456 return inferConvReturnTypeComponents(adaptor, inferredReturnShapes);
4457}
4458
4459LogicalResult Conv2DBlockScaledOp::verify() {
4460 if (failed(verifySameElementTypes(*this, getInputData().getType(),
4461 getWeightData().getType(), "input_data",
4462 "weight_data")) ||
4463 failed(verifySameElementTypes(*this, getInputScale().getType(),
4464 getWeightScale().getType(), "input_scale",
4465 "weight_scale")) ||
4466 failed(verifySameElementTypes(*this, getBias().getType(),
4467 getOutput().getType(), "bias", "output")))
4468 return failure();
4469
4470 // Verify input shape compatibility
4471 int64_t N = ShapedType::kDynamic;
4472 int64_t IH = ShapedType::kDynamic;
4473 int64_t IW = ShapedType::kDynamic;
4474 int64_t IC = ShapedType::kDynamic;
4475 int64_t multiplesOfIC = ShapedType::kDynamic;
4476 int64_t OC = ShapedType::kDynamic;
4477 int64_t KH = ShapedType::kDynamic;
4478 int64_t KW = ShapedType::kDynamic;
4479
4480 const ShapeAdaptor inputDataShape(getInputData().getType());
4481 if (inputDataShape.hasRank()) {
4482 N = inputDataShape.getDimSize(0);
4483 IH = inputDataShape.getDimSize(1);
4484 IW = inputDataShape.getDimSize(2);
4485 IC = inputDataShape.getDimSize(3);
4486 }
4487
4488 const ShapeAdaptor inputScaleShape(getInputScale().getType());
4489 if (inputScaleShape.hasRank()) {
4490 if (failed(tryUpdateDimOrFailure(*this, N, inputScaleShape.getDimSize(0),
4491 "input_scale", "batch size")) ||
4492 failed(tryUpdateDimOrFailure(*this, IH, inputScaleShape.getDimSize(1),
4493 "input_scale", "input height")) ||
4494 failed(tryUpdateDimOrFailure(*this, IW, inputScaleShape.getDimSize(2),
4495 "input_scale", "input width")))
4496 return failure();
4497 multiplesOfIC = inputScaleShape.getDimSize(3);
4498 }
4499
4500 const ShapeAdaptor weightDataShape(getWeightData().getType());
4501 if (weightDataShape.hasRank()) {
4502 OC = weightDataShape.getDimSize(0);
4503 KH = weightDataShape.getDimSize(1);
4504 KW = weightDataShape.getDimSize(2);
4505 if (failed(tryUpdateDimOrFailure(*this, IC, weightDataShape.getDimSize(3),
4506 "weight_data", "input channels")))
4507 return failure();
4508 }
4509
4510 const ShapeAdaptor weightScaleShape(getWeightScale().getType());
4511 if (weightScaleShape.hasRank()) {
4512 if (failed(tryUpdateDimOrFailure(*this, OC, weightScaleShape.getDimSize(0),
4513 "weight_scale", "output channels")) ||
4514 failed(tryUpdateDimOrFailure(*this, KH, weightScaleShape.getDimSize(1),
4515 "weight_scale", "kernel height")) ||
4516 failed(tryUpdateDimOrFailure(*this, KW, weightScaleShape.getDimSize(2),
4517 "weight_scale", "kernel width")) ||
4518 failed(tryUpdateDimOrFailure(*this, multiplesOfIC,
4519 weightScaleShape.getDimSize(3),
4520 "weight_scale", "input channel blocks")))
4521 return failure();
4522 }
4523
4524 const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
4525 if (blockSize != BlockSizeAttr::getBlockSizeValue(BlockSize::BLOCK_SIZE_32))
4526 return emitOpError("expect block size to be 32, got ") << blockSize;
4527 // Verify IC is a multiple of block size
4528 if (ShapedType::isStatic(IC) && IC % blockSize != 0)
4529 return emitOpError("expect IC to be a multiple of block size, got IC=")
4530 << IC << ", block_size=" << blockSize;
4531
4532 // Verify multiplesOfIC is IC / block size
4533 if (ShapedType::isStatic(IC) && ShapedType::isStatic(multiplesOfIC) &&
4534 multiplesOfIC != IC / blockSize)
4535 return emitOpError(
4536 "expect scale operands dimension 2 to equal IC/block_size (")
4537 << IC << "/" << blockSize << ")"
4538 << ", got " << multiplesOfIC;
4539
4540 // Verify pad/stride/dilation values
4541 SmallVector<int64_t> padValues;
4542 if (tosa::getConstShapeValues(getPad().getDefiningOp(), padValues)) {
4543 if (llvm::any_of(padValues, [](int64_t p) { return p < 0; }))
4544 return emitOpError("expect all padding values to be >= 0, got ")
4545 << padValues;
4546 }
4547
4548 SmallVector<int64_t> strideValues;
4549 if (tosa::getConstShapeValues(getStride().getDefiningOp(), strideValues)) {
4550 if (llvm::any_of(strideValues, [](int64_t s) { return s < 1; }))
4551 return emitOpError("expect all stride values to be >= 1, got ")
4552 << strideValues;
4553 }
4554
4555 SmallVector<int64_t> dilationValues;
4556 if (tosa::getConstShapeValues(getDilation().getDefiningOp(),
4557 dilationValues)) {
4558 if (llvm::any_of(dilationValues, [](int64_t d) { return d < 1; }))
4559 return emitOpError("expect all dilation values to be >= 1, got ")
4560 << dilationValues;
4561 }
4562
4563 // Verify output shape compatibility
4564 const ShapeAdaptor outputShape(getOutput().getType());
4565 if (!padValues.empty() && !strideValues.empty() && !dilationValues.empty() &&
4566 outputShape.hasRank()) {
4567 if (failed(verifyConvOutputSize(*this, IH, KH, outputShape.getDimSize(1),
4568 padValues[0], padValues[1], strideValues[0],
4569 dilationValues[0], "height", "y", "top",
4570 "bottom")) ||
4571 failed(verifyConvOutputSize(*this, IW, KW, outputShape.getDimSize(2),
4572 padValues[2], padValues[3], strideValues[1],
4573 dilationValues[1], "width", "x", "left",
4574 "right")))
4575 return failure();
4576 }
4577
4578 // Verify bias
4579 const ShapeAdaptor biasShape(getBias().getType());
4580 if (biasShape.hasRank() && outputShape.hasRank()) {
4581 const int64_t biasChannels = biasShape.getDimSize(0);
4582 const int64_t outputChannels =
4583 outputShape.getDimSize(outputShape.getRank() - 1);
4584 if (biasChannels == ShapedType::kDynamic ||
4585 outputChannels == ShapedType::kDynamic)
4586 // Skip following checks if biasChannels or outputChannels is dynamic dim
4587 return success();
4588
4589 if (biasChannels != outputChannels && biasChannels != 1)
4590 return emitOpError(
4591 "bias channels expected to be equal to output channels (")
4592 << outputChannels << ") or 1, got " << biasChannels;
4593 }
4594
4595 return success();
4596}
4597
4598LogicalResult Conv3DOp::inferReturnTypeComponents(
4599 MLIRContext *context, ::std::optional<Location> location,
4600 Conv3DOp::Adaptor adaptor,
4601 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4602 return inferConvReturnTypeComponents(adaptor, inferredReturnShapes);
4603}
4604
4605LogicalResult Conv3DOp::verify() {
4606 if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
4607 verifyConvOpErrorIf(*this).failed())
4608 return failure();
4609 return success();
4610}
4611
4612LogicalResult AvgPool2dOp::inferReturnTypeComponents(
4613 MLIRContext *context, ::std::optional<Location> location,
4614 AvgPool2dOp::Adaptor adaptor,
4615 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4616 ShapeAdaptor inputShape(adaptor.getInput().getType());
4617 const Properties &prop = adaptor.getProperties();
4618 return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
4619 inferredReturnShapes);
4620}
4621
4622LogicalResult AvgPool2dAdaptiveOp::inferReturnTypeComponents(
4623 MLIRContext *context, ::std::optional<Location> location,
4624 AvgPool2dAdaptiveOp::Adaptor adaptor,
4625 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4626 ShapeAdaptor inputShape(adaptor.getInput().getType());
4627
4628 llvm::SmallVector<int64_t> kernelValues;
4629 llvm::SmallVector<int64_t> strideValues;
4630 llvm::SmallVector<int64_t> padValues;
4631 if (tosa::getConstShapeValues(adaptor.getKernel().getDefiningOp(),
4632 kernelValues) &&
4633 tosa::getConstShapeValues(adaptor.getStride().getDefiningOp(),
4634 strideValues) &&
4635 tosa::getConstShapeValues(adaptor.getPad().getDefiningOp(), padValues)) {
4636 return poolingInferReturnTypes(inputShape, kernelValues, strideValues,
4637 padValues, inferredReturnShapes);
4638 }
4639
4640 llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
4641 if (inputShape.hasRank()) {
4642 // Keep N & C as pooling only changes H & W.
4643 outputShape[0] = inputShape.getDimSize(0);
4644 outputShape[3] = inputShape.getDimSize(3);
4645 }
4646
4647 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
4648 return success();
4649}
4650
4651LogicalResult MaxPool2dOp::inferReturnTypeComponents(
4652 MLIRContext *context, ::std::optional<Location> location,
4653 MaxPool2dOp::Adaptor adaptor,
4654 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4655 ShapeAdaptor inputShape(adaptor.getInput().getType());
4656 const Properties &prop = adaptor.getProperties();
4657 return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
4658 inferredReturnShapes);
4659}
4660
4661LogicalResult MaxPool2dAdaptiveOp::inferReturnTypeComponents(
4662 MLIRContext *context, ::std::optional<Location> location,
4663 MaxPool2dAdaptiveOp::Adaptor adaptor,
4664 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4665 ShapeAdaptor inputShape(adaptor.getInput().getType());
4666
4667 llvm::SmallVector<int64_t> kernelValues;
4668 llvm::SmallVector<int64_t> strideValues;
4669 llvm::SmallVector<int64_t> padValues;
4670 if (tosa::getConstShapeValues(adaptor.getKernel().getDefiningOp(),
4671 kernelValues) &&
4672 tosa::getConstShapeValues(adaptor.getStride().getDefiningOp(),
4673 strideValues) &&
4674 tosa::getConstShapeValues(adaptor.getPad().getDefiningOp(), padValues)) {
4675 return poolingInferReturnTypes(inputShape, kernelValues, strideValues,
4676 padValues, inferredReturnShapes);
4677 }
4678
4679 llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
4680 if (inputShape.hasRank()) {
4681 outputShape[0] = inputShape.getDimSize(0);
4682 outputShape[3] = inputShape.getDimSize(3);
4683 }
4684 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
4685 return success();
4686}
4687
4688LogicalResult MaxPool2dOp::verify() {
4689 if (failed(verifySameElementTypes(*this, /* intype = */ getInput().getType(),
4690 /* outType = */ getOutput().getType())))
4691 return failure();
4692
4693 if (failed(verifyPoolingOp(*this)))
4694 return failure();
4695
4696 return success();
4697}
4698
4699LogicalResult MaxPool2dAdaptiveOp::verify() {
4700 if (failed(verifySameElementTypes(*this, /* intype = */ getInput().getType(),
4701 /* outType = */ getOutput().getType())))
4702 return failure();
4703
4704 AdaptivePoolingConstShapeValues values;
4706
4707 if (failed(verifyPoolingOpImpl(getOperation(), values.kernel, values.stride,
4708 values.pad, getInput(), getOutput())))
4709 return failure();
4710
4711 return success();
4712}
4713
4714LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
4715 MLIRContext *context, ::std::optional<Location> location,
4716 DepthwiseConv2DOp::Adaptor adaptor,
4717 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4718 llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
4719
4720 int64_t inputWidth = ShapedType::kDynamic;
4721 int64_t inputHeight = ShapedType::kDynamic;
4722 int64_t inputChannels = ShapedType::kDynamic;
4723
4724 int64_t weightWidth = ShapedType::kDynamic;
4725 int64_t weightHeight = ShapedType::kDynamic;
4726 int64_t depthChannels = ShapedType::kDynamic;
4727
4728 // Input shape describes input width/height and batch.
4729 ShapeAdaptor inputShape(adaptor.getInput().getType());
4730 if (inputShape.hasRank()) {
4731 outputShape[0] = inputShape.getDimSize(0);
4732 inputHeight = inputShape.getDimSize(1);
4733 inputWidth = inputShape.getDimSize(2);
4734 inputChannels = inputShape.getDimSize(3);
4735 }
4736
4737 // Weight shapes describes the filter width/height and the output channels.
4738 ShapeAdaptor weightShape(adaptor.getWeight().getType());
4739 if (weightShape.hasRank()) {
4740 weightHeight = weightShape.getDimSize(0);
4741 weightWidth = weightShape.getDimSize(1);
4742 inputChannels = ShapedType::isDynamic(inputChannels)
4743 ? weightShape.getDimSize(2)
4744 : inputChannels;
4745 depthChannels = weightShape.getDimSize(3);
4746 }
4747
4748 // If both inputChannels and depthChannels are available we can determine
4749 // the output channels.
4750 if (ShapedType::isStatic(inputChannels) &&
4751 ShapedType::isStatic(depthChannels)) {
4752 outputShape[3] = inputChannels * depthChannels;
4753 }
4754
4755 // Bias shape can describe the output channels.
4756 ShapeAdaptor biasShape(adaptor.getBias().getType());
4757 if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[3])) {
4758 int64_t bc = biasShape.getDimSize(0);
4759 if (bc != ShapedType::kDynamic && bc != 1)
4760 outputShape[3] = bc;
4761 }
4762
4763 llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
4764 llvm::ArrayRef<int64_t> padding = adaptor.getPad();
4765 llvm::ArrayRef<int64_t> stride = adaptor.getStride();
4766
4767 if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
4768 int64_t inputSize = inputHeight + padding[0] + padding[1];
4769 int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
4770 int64_t unstridedResult = inputSize - filterSize + 1;
4771 outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
4772 }
4773
4774 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
4775 int64_t inputSize = inputWidth + padding[2] + padding[3];
4776 int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
4777 int64_t unstridedResult = inputSize - filterSize + 1;
4778 outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
4779 }
4780
4781 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
4782 return success();
4783}
4784
4785LogicalResult DepthwiseConv2DOp::verify() {
4786 if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
4787 verifyConvOpErrorIf(*this).failed())
4788 return failure();
4789 return success();
4790}
4791
4792LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
4793 MLIRContext *context, ::std::optional<Location> location,
4794 TransposeConv2DOp::Adaptor adaptor,
4795 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4796 llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
4797
4798 int64_t inputWidth = ShapedType::kDynamic;
4799 int64_t inputHeight = ShapedType::kDynamic;
4800 int64_t weightWidth = ShapedType::kDynamic;
4801 int64_t weightHeight = ShapedType::kDynamic;
4802
4803 // Input shape describes input width/height and batch.
4804 ShapeAdaptor inputShape(adaptor.getInput().getType());
4805 if (inputShape.hasRank()) {
4806 outputShape[0] = ShapedType::isDynamic(outputShape[0])
4807 ? inputShape.getDimSize(0)
4808 : outputShape[0];
4809 inputHeight = inputShape.getDimSize(1);
4810 inputWidth = inputShape.getDimSize(2);
4811 }
4812
4813 // Weight shapes describes the filter width/height and the output channels.
4814 ShapeAdaptor weightShape(adaptor.getWeight().getType());
4815 if (weightShape.hasRank()) {
4816 outputShape[3] = ShapedType::isDynamic(outputShape[3])
4817 ? weightShape.getDimSize(0)
4818 : outputShape[3];
4819 weightHeight = weightShape.getDimSize(1);
4820 weightWidth = weightShape.getDimSize(2);
4821 }
4822
4823 // Bias shape can describe the output channels.
4824 ShapeAdaptor biasShape(adaptor.getBias().getType());
4825 if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[3])) {
4826 int64_t bc = biasShape.getDimSize(0);
4827 if (bc != ShapedType::kDynamic && bc != 1)
4828 outputShape[3] = bc;
4829 }
4830
4831 llvm::ArrayRef<int64_t> padding = adaptor.getOutPad();
4832 llvm::ArrayRef<int64_t> stride = adaptor.getStride();
4833
4834 if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
4835 int64_t calculateSize =
4836 (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
4837 outputShape[1] =
4838 ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
4839 }
4840
4841 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
4842 int64_t calculateSize =
4843 (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
4844 outputShape[2] =
4845 ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];
4846 }
4847
4848 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
4849 return success();
4850}
4851
4852LogicalResult TransposeConv2DOp::verify() {
4853 if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
4854 return failure();
4855
4856 const llvm::ArrayRef<int64_t> strides = getStride();
4857 const int64_t strideY = strides[0];
4858 const int64_t strideX = strides[1];
4859
4860 if (strideY < 1 || strideX < 1)
4861 return emitOpError("expect all stride values to be >= 1, got [")
4862 << strides << "]";
4863
4864 const auto checkPadAgainstKernelDim =
4865 [this](int64_t padValue, int64_t kernelDimSize, llvm::StringRef padName,
4866 llvm::StringRef kernelDimName) -> LogicalResult {
4867 if (padValue <= -kernelDimSize)
4868 return emitOpError("expected ")
4869 << padName << " > -" << kernelDimName << ", but got: " << padName
4870 << "=" << padValue << " and " << kernelDimName << "="
4871 << kernelDimSize;
4872 return success();
4873 };
4874
4875 const llvm::ArrayRef<int64_t> padding = getOutPad();
4876 const int64_t outPadTop = padding[0];
4877 const int64_t outPadBottom = padding[1];
4878 const int64_t outPadLeft = padding[2];
4879 const int64_t outPadRight = padding[3];
4880
4881 const auto weightType =
4882 llvm::dyn_cast<RankedTensorType>(getWeight().getType());
4883
4884 if (weightType) {
4885 const int64_t kernelHeight = weightType.getDimSize(1);
4886 if (ShapedType::isStatic(kernelHeight)) {
4887 if (failed(checkPadAgainstKernelDim(outPadTop, kernelHeight,
4888 "out_pad_top", "KH")))
4889 return failure();
4890
4891 if (failed(checkPadAgainstKernelDim(outPadBottom, kernelHeight,
4892 "out_pad_bottom", "KH")))
4893 return failure();
4894 }
4895
4896 const int64_t kernelWidth = weightType.getDimSize(2);
4897 if (ShapedType::isStatic(kernelWidth)) {
4898 if (failed(checkPadAgainstKernelDim(outPadLeft, kernelWidth,
4899 "out_pad_left", "KW")))
4900 return failure();
4901
4902 if (failed(checkPadAgainstKernelDim(outPadRight, kernelWidth,
4903 "out_pad_right", "KW")))
4904 return failure();
4905 }
4906 }
4907
4908 // Rest of the checks depend on the output type being a RankedTensorType
4909 const auto outputType =
4910 llvm::dyn_cast<RankedTensorType>(getOutput().getType());
4911 if (!outputType)
4912 return success();
4913
4914 const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
4915 if (inputType && weightType) {
4916 const int64_t inputHeight = inputType.getDimSize(1);
4917 const int64_t kernelHeight = weightType.getDimSize(1);
4918 const int64_t outputHeight = outputType.getDimSize(1);
4919
4920 if (ShapedType::isStatic(inputHeight) &&
4921 ShapedType::isStatic(outputHeight)) {
4922 if (outputHeight !=
4923 (inputHeight - 1) * strideY + outPadTop + outPadBottom + kernelHeight)
4924 return emitOpError(
4925 "dimension mismatch: expected OH == (IH - 1) * stride_y "
4926 "+ out_pad_top + out_pad_bottom + KH, but got ")
4927 << outputHeight << " != (" << inputHeight << " - 1) * "
4928 << strideY << " + " << outPadTop << " + " << outPadBottom
4929 << " + " << kernelHeight;
4930 }
4931
4932 const int64_t inputWidth = inputType.getDimSize(2);
4933 const int64_t kernelWidth = weightType.getDimSize(2);
4934 const int64_t outputWidth = outputType.getDimSize(2);
4935
4936 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(outputWidth)) {
4937 if (outputWidth !=
4938 (inputWidth - 1) * strideX + outPadLeft + outPadRight + kernelWidth)
4939 return emitOpError(
4940 "dimension mismatch: expected OW == (IW - 1) * stride_x "
4941 "+ out_pad_left + out_pad_right + KW, but got ")
4942 << outputWidth << " != (" << inputWidth << " - 1) * " << strideX
4943 << " + " << outPadLeft << " + " << outPadRight << " + "
4944 << kernelWidth;
4945 }
4946 }
4947
4948 const auto biasType = llvm::dyn_cast<RankedTensorType>(getBias().getType());
4949
4950 if (!biasType)
4951 return success();
4952
4953 const int64_t biasChannels = biasType.getDimSize(0);
4954
4955 // Skip further checks if bias is dynamic
4956 if (biasChannels == ShapedType::kDynamic)
4957 return success();
4958
4959 const int64_t outputChannels = outputType.getDimSize(3);
4960 if (!ShapedType::isDynamic(outputChannels) &&
4961 biasChannels != outputChannels && biasChannels != 1)
4962 return emitOpError(
4963 "bias channels expected to be equal to output channels (")
4964 << outputChannels << ") or 1, got " << biasChannels;
4965
4966 return success();
4967}
4968
4969LogicalResult RescaleOp::verify() {
4970 const auto inputType = llvm::cast<ShapedType>(getInput().getType());
4971 auto inputElementType =
4972 getStorageElementTypeOrSelf(inputType.getElementType());
4973 if (!mlir::isa<IntegerType>(inputElementType)) {
4974 emitOpError("expect input to have integer element type, got ")
4975 << inputElementType;
4976 return failure();
4977 }
4978
4979 const auto outputType = llvm::cast<ShapedType>(getOutput().getType());
4980 auto outputElementType =
4981 getStorageElementTypeOrSelf(outputType.getElementType());
4982 if (!mlir::isa<IntegerType>(outputElementType)) {
4983 emitOpError("expect output to have integer element type, got ")
4984 << outputElementType;
4985 return failure();
4986 }
4987
4988 if (verifyRescaleValueAndZpTypes(*this, getInput(), getInputZp(), "input")
4989 .failed())
4990 return failure();
4991
4992 if (verifyRescaleValueAndZpTypes(*this, getOutput(), getOutputZp(), "output")
4993 .failed())
4994 return failure();
4995
4996 FailureOr<int64_t> maybeIZp = getInputZeroPoint();
4997 if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
4998 return failure();
4999
5000 FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
5001 if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
5002 return failure();
5003
5004 const auto multiplierType = llvm::cast<ShapedType>(getMultiplier().getType());
5005 // multiplier element type must be i32 for scale32 = true
5006 if (getScale32() && !multiplierType.getElementType().isInteger(32)) {
5007 emitOpError("expect i32 element type for multiplier for scale32=true, got ")
5008 << multiplierType.getElementType();
5009 return failure();
5010 }
5011
5012 // multiplier element type must be i16 for scale32 = false
5013 if (!getScale32() && !multiplierType.getElementType().isInteger(16)) {
5015 "expect i16 element type for multiplier for scale32=false, got ")
5016 << multiplierType.getElementType();
5017 return failure();
5018 }
5019
5020 if (!inputType.hasRank())
5021 return success();
5022
5023 // multiplier/shift must have shape = {numChannels},
5024 // where numChannel is 1 if per_channel = false
5025 // otherwise numChannel is dimension in input shape's last axis
5026 int64_t numChannels = 1;
5027 if (getPerChannel()) {
5028 if (inputType.getRank() < 1) {
5029 emitOpError("requires input to be at least rank 1 when per_channel is "
5030 "true, but got rank ")
5031 << inputType.getRank();
5032 return failure();
5033 }
5034 numChannels = inputType.getDimSize(inputType.getRank() - 1);
5035 }
5036
5037 if (outputType.hasRank()) {
5039 getOperation(), outputType, inputType.getShape())))
5040 return failure();
5041 }
5042
5043 if (multiplierType.hasRank()) {
5044 ArrayRef<int64_t> multiplierShape = multiplierType.getShape();
5045 // multiplier input has rank 1 by dialect definition
5046 if (multiplierShape[0] != ShapedType::kDynamic &&
5047 multiplierShape[0] != numChannels) {
5048 emitOpError("expect shape of { ")
5049 << numChannels << " } for multiplier input, got { "
5050 << multiplierShape[0] << " }";
5051 return failure();
5052 }
5053 }
5054
5055 const auto shiftType = llvm::cast<ShapedType>(getShift().getType());
5056 if (shiftType.hasRank()) {
5057 ArrayRef<int64_t> shiftShape = shiftType.getShape();
5058 // shift input has rank 1 by dialect definition
5059 if (shiftShape[0] != ShapedType::kDynamic && shiftShape[0] != numChannels) {
5060 emitOpError("expect shape of { ")
5061 << numChannels << " } for shift input, got { " << shiftShape[0]
5062 << " }";
5063 return failure();
5064 }
5065 }
5066
5067 return success();
5068}
5069
5070LogicalResult RescaleOp::inferReturnTypeComponents(
5071 MLIRContext *context, ::std::optional<Location> location,
5072 RescaleOp::Adaptor adaptor,
5073 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
5074 ShapeAdaptor inputShape(adaptor.getInput().getType());
5075 inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
5076 return success();
5077}
5078
5079LogicalResult CastFromBlockScaledOp::inferReturnTypeComponents(
5080 MLIRContext *context, ::std::optional<Location> location,
5081 CastFromBlockScaledOp::Adaptor adaptor,
5082 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
5083 const ShapeAdaptor inputShape(adaptor.getInputData().getType());
5084 inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
5085 return success();
5086}
5087
5088LogicalResult CastFromBlockScaledOp::verify() {
5089 const Type inputDataType = getInputData().getType();
5090 const Type outputDataType = getResult().getType();
5091 if (failed(verifyCompatibleShape(inputDataType, outputDataType)))
5092 return emitOpError() << "require compatible shapes for input_data ("
5093 << inputDataType << ") and " << "output_data ("
5094 << outputDataType << ")";
5095
5096 const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
5097
5098 if (inputDataShape.hasRank()) {
5099 const unsigned int blockSize =
5100 BlockSizeAttr::getBlockSizeValue(getBlockSize());
5101 if (blockSize != BlockSizeAttr::getBlockSizeValue(BlockSize::BLOCK_SIZE_32))
5102 return emitOpError("expect block size to be 32, got ") << blockSize;
5103 const int64_t inputDataLastDim =
5104 inputDataShape.getDimSize(inputDataShape.getRank() - 1);
5105 if (inputDataLastDim % blockSize != 0)
5106 return emitOpError() << "expect last dimension of input_data ("
5107 << inputDataLastDim
5108 << ") to be divisible by block_size (" << blockSize
5109 << ")";
5110
5111 const Type inputScaleType = getInputScale().getType();
5112 const ShapeAdaptor inputScaleShape = ShapeAdaptor(inputScaleType);
5113
5114 if (inputScaleShape.hasRank()) {
5115 SmallVector<int64_t> inputDataDims, inputScaleDims;
5116 inputDataShape.getDims(inputDataDims);
5117 inputScaleShape.getDims(inputScaleDims);
5118
5119 if (inputDataDims.size() != inputScaleDims.size() ||
5121 ArrayRef<int64_t>(inputDataDims).drop_back(1),
5122 ArrayRef<int64_t>(inputScaleDims).drop_back(1))))
5123 return emitOpError()
5124 << "require compatible shapes for input_data (" << inputDataType
5125 << ") and " << "input_scale (" << inputScaleType
5126 << ") except for the last dimension";
5127
5128 const SmallVector<int64_t, 2> dimsToCheck{inputDataLastDim / blockSize,
5129 inputScaleDims.back()};
5130 if (ShapedType::isStatic(inputDataLastDim) &&
5131 failed(verifyCompatibleDims(dimsToCheck)))
5132 return emitOpError()
5133 << "expect last dimension of input_scale ("
5134 << inputScaleDims.back()
5135 << ") to be equal to last dimension of input_data / block_size ("
5136 << inputDataDims.back() / blockSize << ")";
5137 }
5138 }
5139
5140 return success();
5141}
5142
5143LogicalResult CastToBlockScaledOp::inferReturnTypeComponents(
5144 MLIRContext *context, ::std::optional<Location> location,
5145 CastToBlockScaledOp::Adaptor adaptor,
5146 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
5147 const ShapeAdaptor inputShape(adaptor.getInputData().getType());
5148 inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
5149 if (!inputShape.hasRank())
5150 return success();
5151
5152 // Calculate output_scale shape if ranked input provided
5153 SmallVector<int64_t> outputScaleShape;
5154 inputShape.getDims(outputScaleShape);
5155 const int64_t lastDimLoc = inputShape.getRank() - 1;
5156 const int64_t lastDimSize = inputShape.getDimSize(lastDimLoc);
5157 if (ShapedType::isStatic(lastDimSize)) {
5158 const unsigned int blockSize =
5159 BlockSizeAttr::getBlockSizeValue(adaptor.getBlockSize());
5160 outputScaleShape[lastDimLoc] = lastDimSize / blockSize;
5161 }
5162 inferredReturnShapes.push_back(ShapedTypeComponents(outputScaleShape));
5163 return success();
5164}
5165
5166LogicalResult CastToBlockScaledOp::verify() {
5167 const Type inputDataType = getInputData().getType();
5168 const Type outputDataType = getResult(0).getType();
5169 if (failed(verifyCompatibleShape(inputDataType, outputDataType)))
5170 return emitOpError() << "require compatible shapes for input_data ("
5171 << inputDataType << ") and " << "output_data ("
5172 << outputDataType << ")";
5173
5174 const unsigned int blockSize =
5175 BlockSizeAttr::getBlockSizeValue(getBlockSize());
5176 if (blockSize != BlockSizeAttr::getBlockSizeValue(BlockSize::BLOCK_SIZE_32))
5177 return emitOpError("expect block size to be 32, got ") << blockSize;
5178 const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
5179 if (inputDataShape.hasRank()) {
5180 const int64_t inputDataLastDim =
5181 inputDataShape.getDimSize(inputDataShape.getRank() - 1);
5182 if (ShapedType::isStatic(inputDataLastDim) &&
5183 inputDataLastDim % blockSize != 0)
5184 return emitOpError() << "expect last dimension of input_data ("
5185 << inputDataLastDim
5186 << ") to be divisible by block_size (" << blockSize
5187 << ")";
5188 }
5189
5190 const ShapeAdaptor outputDataShape = ShapeAdaptor(outputDataType);
5191 const Type outputScaleType = getResult(1).getType();
5192 const ShapeAdaptor outputScaleShape = ShapeAdaptor(outputScaleType);
5193 if (outputDataShape.hasRank() && outputScaleShape.hasRank()) {
5194 SmallVector<int64_t> outputDataDims, outputScaleDims;
5195 outputDataShape.getDims(outputDataDims);
5196 outputScaleShape.getDims(outputScaleDims);
5197
5198 if (outputDataDims.size() != outputScaleDims.size() ||
5200 ArrayRef<int64_t>(outputDataDims).drop_back(1),
5201 ArrayRef<int64_t>(outputScaleDims).drop_back(1))))
5202 return emitOpError() << "require compatible shapes for output_data ("
5203 << outputDataType << ") and " << "output_scale ("
5204 << outputScaleType
5205 << ") except for the last dimension";
5206
5207 const int64_t outputDataLastDim = outputDataDims.back();
5208 const SmallVector<int64_t, 2> dimsToCheck{outputDataLastDim / blockSize,
5209 outputScaleDims.back()};
5210 if (ShapedType::isStatic(outputDataLastDim) &&
5211 failed(verifyCompatibleDims(dimsToCheck)))
5212 return emitOpError()
5213 << "expect last dimension of output_scale ("
5214 << outputScaleDims.back()
5215 << ") to be equal to last dimension of output_data / block_size ("
5216 << outputDataDims.back() / blockSize << ")";
5217 }
5218
5219 return success();
5220}
5221
5222LogicalResult IfOp::inferReturnTypeComponents(
5223 MLIRContext *context, ::std::optional<Location> location,
5224 IfOp::Adaptor adaptor,
5225 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
5226 llvm::SmallVector<tosa::YieldOp> yieldOps;
5227 for (Region *region : adaptor.getRegions()) {
5228 for (auto &block : *region)
5229 if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
5230 yieldOps.push_back(returnOp);
5231 }
5232
5233 if (yieldOps.empty())
5234 return failure();
5235
5236 // Get the initial type information for the yield op.
5237 llvm::SmallVector<ValueKnowledge> resultKnowledge;
5238 resultKnowledge.reserve(yieldOps.front().getNumOperands());
5239 for (auto operand : yieldOps.front().getOperands()) {
5240 resultKnowledge.push_back(
5241 ValueKnowledge::getKnowledgeFromType(operand.getType()));
5242 }
5243
5244 for (auto yieldOp : yieldOps) {
5245 if (resultKnowledge.size() != yieldOp.getNumOperands())
5246 return failure();
5247
5248 for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
5249 int32_t index = it.index();
5250 auto meet = ValueKnowledge::meet(
5251 resultKnowledge[index],
5252 ValueKnowledge::getKnowledgeFromType(it.value().getType()));
5253 if (!meet)
5254 continue;
5255 resultKnowledge[index] = meet;
5256 }
5257 }
5258
5259 for (const ValueKnowledge &result : resultKnowledge) {
5260 inferredReturnShapes.push_back(result.getShapedTypeComponents());
5261 }
5262
5263 return success();
5264}
5265
5266LogicalResult WhileOp::inferReturnTypeComponents(
5267 MLIRContext *context, ::std::optional<Location> location,
5268 WhileOp::Adaptor adaptor,
5269 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
5270 llvm::SmallVector<tosa::YieldOp> yieldOps;
5271 for (auto &block : adaptor.getBodyGraph())
5272 if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
5273 yieldOps.push_back(returnOp);
5274
5275 // TOSA's while must have a tosa.yield as its terminator. If not found this
5276 // tosa.while is invalid.
5277 if (yieldOps.empty())
5278 return failure();
5279
5280 // Get the initial type information from the operand types.
5281 llvm::SmallVector<ValueKnowledge> resultKnowledge;
5282 resultKnowledge.reserve(yieldOps.front().getNumOperands());
5283 for (auto operand : yieldOps.front().getOperands()) {
5284 resultKnowledge.push_back(
5285 ValueKnowledge::getKnowledgeFromType(operand.getType()));
5286 }
5287
5288 for (auto yieldOp : yieldOps) {
5289 if (resultKnowledge.size() != yieldOp.getNumOperands())
5290 return failure();
5291
5292 for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
5293 int32_t index = it.index();
5294 if (auto meet = ValueKnowledge::meet(
5295 resultKnowledge[index],
5296 ValueKnowledge::getKnowledgeFromType(it.value().getType()))) {
5297 resultKnowledge[index] = meet;
5298 }
5299 }
5300 }
5301
5302 for (const ValueKnowledge &result : resultKnowledge) {
5303 inferredReturnShapes.push_back(result.getShapedTypeComponents());
5304 }
5305
5306 return success();
5307}
5308
5309std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
5310 if (auto vt = llvm::dyn_cast<VectorType>(getType()))
5311 return llvm::to_vector<4>(vt.getShape());
5312 return std::nullopt;
5313}
5314
5316 Block::BlockArgListType blocksArgs,
5317 ValueRange initializers,
5318 StringRef prefix = "") {
5319 assert(blocksArgs.size() == initializers.size() &&
5320 "expected same length of arguments and initializers");
5321 if (initializers.empty())
5322 return;
5323
5324 parser << prefix << '(';
5325 llvm::interleaveComma(
5326 llvm::zip(blocksArgs, initializers), parser,
5327 [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
5328 parser << ")";
5329}
5330
5331// parse and print of IfOp refer to the implementation of SCF dialect.
5332ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
5333 // Create the regions for 'then'.
5334 result.regions.reserve(2);
5335 Region *thenRegion = result.addRegion();
5336 Region *elseRegion = result.addRegion();
5337
5338 OpAsmParser::UnresolvedOperand cond;
5339
5340 if (parser.parseOperand(cond))
5341 return failure();
5342
5343 SmallVector<OpAsmParser::Argument, 4> regionArgs;
5344 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
5345
5346 // Parse the optional block arguments
5347 OptionalParseResult listResult =
5348 parser.parseOptionalAssignmentList(regionArgs, operands);
5349 if (listResult.has_value() && failed(listResult.value()))
5350 return failure();
5351
5352 // Parse a colon.
5353 if (failed(parser.parseColon()))
5354 return parser.emitError(parser.getCurrentLocation(),
5355 "expected type for condition operand");
5356
5357 // Parse the type of the condition operand
5358 Type condType;
5359 if (failed(parser.parseType(condType)))
5360 return parser.emitError(parser.getCurrentLocation(),
5361 "expected type for condition operand");
5362
5363 // Resolve operand with provided type
5364 if (failed(parser.resolveOperand(cond, condType, result.operands)))
5365 return failure();
5366
5367 // Parse optional block arg types
5368 if (listResult.has_value()) {
5369 FunctionType functionType;
5370
5371 if (failed(parser.parseType(functionType)))
5372 return parser.emitError(parser.getCurrentLocation())
5373 << "expected list of types for block arguments "
5374 << "followed by arrow type and list of return types";
5375
5376 result.addTypes(functionType.getResults());
5377
5378 if (functionType.getNumInputs() != operands.size()) {
5379 return parser.emitError(parser.getCurrentLocation())
5380 << "expected as many input types as operands " << "(expected "
5381 << operands.size() << " got " << functionType.getNumInputs()
5382 << ")";
5383 }
5384
5385 // Resolve input operands.
5386 if (failed(parser.resolveOperands(operands, functionType.getInputs(),
5387 parser.getCurrentLocation(),
5388 result.operands)))
5389 return failure();
5390 } else {
5391 // Parse optional results type list.
5392 if (parser.parseOptionalArrowTypeList(result.types))
5393 return failure();
5394 }
5395
5396 // Parse the 'then' region.
5397 if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
5398 return failure();
5399
5400 // If we find an 'else' keyword then parse the 'else' region.
5401 if (!parser.parseOptionalKeyword("else")) {
5402 if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
5403 return failure();
5404 }
5405
5406 // Parse the optional attribute list.
5407 if (parser.parseOptionalAttrDict(result.attributes))
5408 return failure();
5409 return success();
5410}
5411
5412void IfOp::print(OpAsmPrinter &p) {
5413 p << " " << getCondition();
5414
5415 printInitializationList(p, getThenGraph().front().getArguments(),
5416 getInputList(), " ");
5417 p << " : ";
5418 p << getCondition().getType();
5419
5420 if (!getInputList().empty()) {
5421 p << " (";
5422 llvm::interleaveComma(getInputList().getTypes(), p);
5423 p << ")";
5424 }
5425 p.printArrowTypeList(getResultTypes());
5426 p << " ";
5427
5428 p.printRegion(getThenGraph());
5429
5430 // Print the 'else' regions if it exists and has a block.
5431 auto &elseRegion = getElseGraph();
5432 if (!elseRegion.empty()) {
5433 p << " else ";
5434 p.printRegion(elseRegion);
5435 }
5436
5437 p.printOptionalAttrDict((*this)->getAttrs());
5438}
5439
5440LogicalResult IfOp::verify() {
5441 if (errorIfTypeOrShapeMismatch(*this, getThenGraph().front().getArguments(),
5442 "'then_graph' arguments", getInputList(),
5443 "'input_list'")
5444 .failed())
5445 return failure();
5446
5447 if (errorIfTypeOrShapeMismatch(*this, getElseGraph().front().getArguments(),
5448 "'else_graph' arguments", getInputList(),
5449 "'input_list'")
5450 .failed())
5451 return failure();
5452
5453 // MLIR will verify the absence of the terminator for us if otherwise.
5454 if (getThenGraph().front().mightHaveTerminator()) {
5455 auto thenYield =
5456 dyn_cast<tosa::YieldOp>(getThenGraph().front().getTerminator());
5457 if (thenYield && errorIfTypeOrShapeMismatch(
5458 *this, thenYield.getInputs(), "'then_graph' results",
5459 getOutputList(), "'output_list'")
5460 .failed())
5461 return failure();
5462 }
5463
5464 // MLIR will verify the absence of the terminator for us if otherwise.
5465 if (getElseGraph().front().mightHaveTerminator()) {
5466 auto elseYield =
5467 dyn_cast<tosa::YieldOp>(getElseGraph().front().getTerminator());
5468 if (elseYield && errorIfTypeOrShapeMismatch(
5469 *this, elseYield.getInputs(), "'else_graph' results",
5470 getOutputList(), "'output_list'")
5471 .failed())
5472 return failure();
5473 }
5474
5475 auto condType = getCondition().getType();
5476 if (errorIfShapeNotSizeOne(*this, condType).failed())
5477 return emitOpError() << "'condition' must be a size 1 tensor, got "
5478 << condType;
5479
5480 return success();
5481}
5482
5483LogicalResult WhileOp::verify() {
5484 if (errorIfTypeOrShapeMismatch(*this, getInputList(), "'input_list'",
5485 getOutputList(), "'output_list'")
5486 .failed())
5487 return failure();
5488
5489 if (errorIfTypeOrShapeMismatch(*this, getCondGraph().front().getArguments(),
5490 "'cond_graph' arguments", getInputList(),
5491 "'input_list'")
5492 .failed())
5493 return failure();
5494
5495 if (errorIfTypeOrShapeMismatch(*this, getBodyGraph().front().getArguments(),
5496 "'body_graph' arguments", getInputList(),
5497 "'input_list'")
5498 .failed())
5499 return failure();
5500
5501 if (getBodyGraph().front().mightHaveTerminator()) {
5502 auto bodyYield =
5503 dyn_cast<tosa::YieldOp>(getBodyGraph().front().getTerminator());
5504 if (bodyYield && errorIfTypeOrShapeMismatch(*this, bodyYield.getInputs(),
5505 "'body_graph' results",
5506 getInputList(), "'input_list'")
5507 .failed())
5508 return failure();
5509 }
5510
5511 // Condition block output must be a single element tensor with a single bool
5512 // value.
5513 if (!getCondGraph().front().mightHaveTerminator())
5514 return success();
5515
5516 auto condYield =
5517 dyn_cast<tosa::YieldOp>(getCondGraph().front().getTerminator());
5518 if (!condYield)
5519 return success();
5520
5521 if (condYield.getInputs().size() != 1)
5522 return emitOpError() << "require 'cond_graph' only have one result";
5523
5524 auto condOutType = condYield.getInputs()[0].getType();
5525 if (errorIfShapeNotSizeOne(*this, condOutType).failed())
5526 return emitOpError() << "'cond_graph' result must be a size 1 tensor, got "
5527 << condOutType;
5528
5529 if (!getElementTypeOrSelf(condOutType).isInteger(1))
5530 return emitOpError() << "'cond_graph' result must be a boolean tensor, got "
5531 << condOutType;
5532
5533 return success();
5534}
5535
5536LogicalResult ReverseOp::verify() {
5537 TensorType inputType = getInput1().getType();
5538 int32_t reverseAxis = getAxis();
5539
5540 if (reverseAxis < 0)
5541 return emitOpError("expected non-negative reverse axis");
5542 if (inputType.hasRank()) {
5543 int64_t inputRank = inputType.getRank();
5544 // We allow for a special case where the input/output shape has rank 0 and
5545 // axis is also 0.
5546 if (reverseAxis >= inputRank && (reverseAxis != 0 || inputRank != 0))
5547 return emitOpError("expect input tensor rank (")
5548 << inputRank << ") to be larger than reverse axis (" << reverseAxis
5549 << ")";
5550 }
5551
5552 return success();
5553}
5554
5555LogicalResult tosa::SelectOp::verify() {
5556 // verify input2 and input3 have same element type as output
5557 if (verifySameElementTypes(*this, /* inType = */ getOnTrue().getType(),
5558 /* outType = */ getOutput().getType())
5559 .failed() ||
5560 verifySameElementTypes(*this, /* inType = */ getOnFalse().getType(),
5561 /* outType = */ getOutput().getType())
5562 .failed()) {
5563 return failure();
5564 }
5565 // verify input1 has element type of bool
5566 auto predicateType = llvm::dyn_cast<ShapedType>(getPred().getType());
5567 if (!predicateType) {
5568 return emitOpError("expect shaped tensor for input1, got ")
5569 << getInput1().getType();
5570 }
5571 auto predicateElementType = predicateType.getElementType();
5572 if (!predicateElementType.isInteger(1)) {
5573 return emitOpError("expect element type of bool for input1, got ")
5574 << predicateElementType;
5575 }
5576
5577 return success();
5578}
5579
5580LogicalResult tosa::VariableReadOp::verify() {
5581 if (verifyVariableOpErrorIf(*this, getOutput1().getType(), "'output1'")
5582 .failed())
5583 return failure();
5584
5585 return success();
5586}
5587
5588LogicalResult tosa::VariableWriteOp::verify() {
5589 if (verifyVariableOpErrorIf(*this, getInput1().getType(), "'input1'")
5590 .failed())
5591 return failure();
5592
5593 return success();
5594}
5595
5596// parse and print of WhileOp refer to the implementation of SCF dialect.
5597ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
5598 SmallVector<OpAsmParser::Argument, 4> regionArgs;
5599 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
5600 Region *cond = result.addRegion();
5601 Region *body = result.addRegion();
5602
5603 OptionalParseResult listResult =
5604 parser.parseOptionalAssignmentList(regionArgs, operands);
5605 if (listResult.has_value() && failed(listResult.value()))
5606 return failure();
5607
5608 FunctionType functionType;
5609 SMLoc typeLoc = parser.getCurrentLocation();
5610 if (failed(parser.parseColonType(functionType)))
5611 return failure();
5612
5613 result.addTypes(functionType.getResults());
5614
5615 if (functionType.getNumInputs() != operands.size()) {
5616 return parser.emitError(typeLoc)
5617 << "expected as many input types as operands " << "(expected "
5618 << operands.size() << " got " << functionType.getNumInputs() << ")";
5619 }
5620
5621 // Resolve input operands.
5622 if (failed(parser.resolveOperands(operands, functionType.getInputs(),
5623 parser.getCurrentLocation(),
5624 result.operands)))
5625 return failure();
5626
5627 // Propagate the types into the region arguments.
5628 for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
5629 regionArgs[i].type = functionType.getInput(i);
5630
5631 return failure(parser.parseRegion(*cond, regionArgs) ||
5632 parser.parseKeyword("do") || parser.parseRegion(*body) ||
5633 parser.parseOptionalAttrDictWithKeyword(result.attributes));
5634}
5635
5636void WhileOp::print(OpAsmPrinter &parser) {
5637 printInitializationList(parser, getCondGraph().front().getArguments(),
5638 getInputList(), " ");
5639 parser << " : ";
5640 parser.printFunctionalType(getInputList().getTypes(),
5641 getResults().getTypes());
5642 parser << ' ';
5643 parser.printRegion(getCondGraph(), /*printEntryBlockArgs=*/false);
5644 parser << " do ";
5645 parser.printRegion(getBodyGraph());
5646 parser.printOptionalAttrDictWithKeyword((*this)->getAttrs());
5647}
5648
5649// Create a rank-1 const tensor for zero point of the source tensor.
5650std::optional<Value> mlir::tosa::createZeroPointTensor(OpBuilder &builder,
5651 Location loc,
5652 Type srcElemType,
5653 int64_t zp) {
5654 srcElemType = getStorageElementTypeOrSelf(srcElemType);
5655 auto zpType = mlir::RankedTensorType::get({1}, srcElemType);
5656 if (llvm::isa<FloatType>(srcElemType)) {
5657 auto zpAttr = DenseElementsAttr::get(
5658 zpType, builder.getFloatAttr(srcElemType, static_cast<double>(zp)));
5659 return tosa::ConstOp::create(builder, loc, zpType, zpAttr);
5660 }
5661 if (llvm::isa<IntegerType>(srcElemType)) {
5662 auto zpAttr =
5663 DenseElementsAttr::get(zpType, builder.getIntegerAttr(srcElemType, zp));
5664 return tosa::ConstOp::create(builder, loc, zpType, zpAttr);
5665 }
5666 llvm::errs() << "zero point is not allowed for unsupported data types\n";
5667 return std::nullopt;
5668}
5669
5670//===----------------------------------------------------------------------===//
5671// TOSA Shape and Shape Operators Helper functions.
5672//===----------------------------------------------------------------------===//
5673
5675 return mlir::isa<tosa::shapeType>(t);
5676}
5677
5678LogicalResult
5679mlir::tosa::shapeType::verify(function_ref<InFlightDiagnostic()> emitError,
5680 int rank) {
5681 if (rank < 0)
5682 return emitError() << "invalid rank (must be >= 0): " << rank;
5683 return success();
5684}
5685
5687 for (auto v : op->getOperands()) {
5688 if (mlir::isa<::mlir::tosa::shapeType>(v.getType())) {
5689 Operation *definingOp = v.getDefiningOp();
5690 if (!definingOp || !definingOp->hasTrait<TosaShapeOperator>()) {
5691 return op->emitOpError("shape operand is not compile time resolvable");
5692 }
5693 }
5694 }
5695 return success();
5696}
5697
5698LogicalResult
5700 if (failed(OpTrait::impl::verifyAtLeastNOperands(op, 1)))
5701 return failure();
5702
5703 // delegate function that returns rank of shape type
5704 auto getRank = [](const Type type) {
5705 return mlir::cast<mlir::tosa::shapeType>(type).getRank();
5706 };
5707 auto operandTypes = op->getOperandTypes();
5708 auto resultTypes = op->getResultTypes();
5709
5710 auto rank = getRank(*op->getOperandTypes().begin());
5711 for (auto type : operandTypes) {
5712 if (getRank(type) != rank) {
5713 return op->emitOpError("operands don't have matching ranks");
5714 }
5715 }
5716 for (auto type : resultTypes) {
5717 if (getRank(type) != rank) {
5718 return op->emitOpError("result shape has different rank than operands");
5719 }
5720 }
5721 return success();
5722}
5723
5724//===----------------------------------------------------------------------===//
5725// TOSA Shape Operators verify functions.
5726//===----------------------------------------------------------------------===//
5727
5728LogicalResult tosa::ConstShapeOp::verify() {
5729 // check one dimensional rank
5730 auto valuesRank = getValues().getType().getRank();
5731 if (valuesRank != 1)
5732 return emitOpError("expect elements in attribute values with rank 1");
5733 // check that number of elements in values attr equal to rank of result shape
5734 auto count = getValues().getNumElements();
5735 auto rank = (cast<tosa::shapeType>(getResult().getType())).getRank();
5736 if (count != rank && (count != 1 || rank != 0)) {
5737 return emitOpError("expect number of elements in attribute values (")
5738 << count << ") to be equal to the rank (" << rank
5739 << ") for the result shape type";
5740 }
5741 return success();
5742}
5743
5744LogicalResult tosa::DimOp::verify() {
5745 const tosa::shapeType outShapeType =
5746 cast<tosa::shapeType>(getResult().getType());
5747 if (outShapeType.getRank() != 1)
5748 return emitOpError("expect output shape type to contain one element, got ")
5749 << outShapeType;
5750
5751 const ShapeAdaptor inputType(getInput1().getType());
5752 if (inputType.hasRank()) {
5753 const int64_t inputRank = inputType.getRank();
5754 const int64_t axis = getAxisAttr().getInt();
5755 if (axis < 0 || axis >= inputRank)
5756 return emitOpError("expect axis to be in the range [0, ")
5757 << inputRank << "), got " << axis;
5758 }
5759 return success();
5760}
5761
5762LogicalResult tosa::ConcatShapeOp::verify() {
5763 const tosa::shapeType outShapeType =
5764 cast<tosa::shapeType>(getResult().getType());
5765 const int64_t outputRank = outShapeType.getRank();
5766 const Operation::operand_range inputList = getInput();
5767
5768 if (inputList.size() == 0)
5769 return emitOpError("requires at least one input shape");
5770
5771 if (llvm::any_of(inputList, [](Value v) {
5772 return cast<tosa::shapeType>(v.getType()).getRank() == 0;
5773 }))
5774 return emitOpError("requires all inputs shapes have a rank greater than 0");
5775
5776 const int64_t inputsRank =
5777 llvm::accumulate(inputList, 0, [](int64_t acc, const Value &input) {
5778 const tosa::shapeType inShapeType =
5779 cast<tosa::shapeType>(input.getType());
5780 return acc + inShapeType.getRank();
5781 });
5782 if (outputRank != inputsRank)
5783 return emitOpError("requires output shape rank to be equal to the sum of "
5784 "the input shape ranks (")
5785 << inputsRank << "), got " << outputRank;
5786
5787 return success();
5788}
5789
5790LogicalResult tosa::SliceShapeOp::verify() {
5791 std::optional<int32_t> start;
5792 DenseIntElementsAttr startAttr;
5793 if (matchPattern(getStart(), m_Constant(&startAttr)))
5794 start = startAttr.getValues<int32_t>()[0];
5795 if (start && start.value() < 0)
5796 return emitOpError("expected non-negative start index, got ")
5797 << start.value();
5798
5799 std::optional<int32_t> size;
5800 DenseIntElementsAttr sizeAttr;
5801 if (matchPattern(getSize(), m_Constant(&sizeAttr)))
5802 size = sizeAttr.getValues<int32_t>()[0];
5803 if (size && size.value() <= 0)
5804 return emitOpError("expected positive size, got ") << size.value();
5805
5806 if (!size)
5807 return success();
5808
5809 const tosa::shapeType outShapeType =
5810 cast<tosa::shapeType>(getResult().getType());
5811 const int64_t outputRank = outShapeType.getRank();
5812 if (outputRank != size)
5813 return emitOpError(
5814 "expected output type size to be equal to size attribute, got ")
5815 << outputRank << " vs " << size.value();
5816
5817 if (!start)
5818 return success();
5819
5820 const tosa::shapeType inShapeType =
5821 cast<tosa::shapeType>(getInput().getType());
5822 const int64_t inputRank = inShapeType.getRank();
5823 const int64_t sliceSize = start.value() + size.value();
5824 if (sliceSize > inputRank)
5825 return emitOpError("expected start + size to be less than or equal to "
5826 "input shape rank (")
5827 << inputRank << "), got " << sliceSize;
5828
5829 return success();
5830}
5831
5832//===----------------------------------------------------------------------===//
5833// TOSA Attribute Definitions.
5834//===----------------------------------------------------------------------===//
5835
5836#define GET_ATTRDEF_CLASSES
5837#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
5838
5839//===----------------------------------------------------------------------===//
5840// TOSA Type Definitions.
5841//===----------------------------------------------------------------------===//
5842#define GET_TYPEDEF_CLASSES
5843#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
5844
5845//===----------------------------------------------------------------------===//
5846// TOSA Operator Definitions.
5847//===----------------------------------------------------------------------===//
5848
5849#define GET_OP_CLASSES
5850#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
Definition SCF.cpp:496
true
Given two iterators into the same block, return "true" if a is before `b.
lhs
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
static std::string diag(const llvm::Value &value)
static void printShapeToDiagnostic(InFlightDiagnostic &diag, ArrayRef< int64_t > shape)
Definition TosaOps.cpp:653
static void buildMatMulOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value a, Value b)
The tosa.matmul op is also intended to be generated where a fully_connected op must be constructed wh...
Definition TosaOps.cpp:1451
static LogicalResult verifySameElementTypes(T op, Type aType, Type bType, StringRef aName="input", StringRef bName="output")
Definition TosaOps.cpp:1077
LogicalResult inferConvReturnTypeComponents(AdaptorT adaptor, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
Definition TosaOps.cpp:4387
static SmallVector< int64_t > convertToMlirShape(ArrayRef< int64_t > shape)
Definition TosaOps.cpp:139
static LogicalResult ReduceInferReturnTypes(ShapeAdaptor operandShape, Type inputType, IntegerAttr axis, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
Definition TosaOps.cpp:3935
static void buildAvgPool2dAdaptiveOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, DenseI64ArrayAttr kernel, DenseI64ArrayAttr stride, DenseI64ArrayAttr pad, TypeAttr accType)
This builder mirrors avg_pool2d quant-info handling and materializes kernel/stride/pad as const_shape...
Definition TosaOps.cpp:1524
static LogicalResult verifyRescaleValueAndZpTypes(Operation *op, Value val, Value valZp, StringRef name)
Definition TosaOps.cpp:595
static LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type)
Definition TosaOps.cpp:1041
#define REDUCE_SHAPE_INFER(OP)
Definition TosaOps.cpp:3960
static LogicalResult verifyConvOp(T op)
Definition TosaOps.cpp:742
static LogicalResult verifyAvgPoolCommonTypeAndZpChecks(T op)
Definition TosaOps.cpp:1240
static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name)
Definition TosaOps.cpp:1050
static LogicalResult poolingInferReturnTypes(ShapeAdaptor inputShape, ArrayRef< int64_t > kernel, ArrayRef< int64_t > stride, ArrayRef< int64_t > pad, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
Definition TosaOps.cpp:4150
static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value paddings)
This builder is called on TOSA pad operator that needs to create its own OptionalAttr quantization_at...
Definition TosaOps.cpp:1611
static LogicalResult verifyPoolingOpImpl(Operation *op, ArrayRef< int64_t > kernel, ArrayRef< int64_t > strides, ArrayRef< int64_t > padding, Value input, Value output)
Definition TosaOps.cpp:1142
static std::optional< int64_t > idivCheck(const int64_t lhs, const int64_t rhs)
Definition TosaOps.cpp:578
static void buildVariableOp(OpBuilder &builder, OperationState &result, StringRef name, Type variableType, Attribute initialValue)
Definition TosaOps.cpp:1625
LogicalResult verifyConvOutputSize(Operation *op, const int64_t inputSize, const int64_t kernelSize, const int64_t outputSize, const int64_t padBefore, const int64_t padAfter, const int64_t stride, const int64_t dilation, const llvm::StringRef dimName, const llvm::StringRef dimAxis, const llvm::StringRef padBeforeName, const llvm::StringRef padAfterName)
Definition TosaOps.cpp:682
static LogicalResult verifyReduceOp(T op)
Definition TosaOps.cpp:3985
#define NARY_SHAPE_INFER(OP)
Definition TosaOps.cpp:4053
#define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND)
Definition TosaOps.cpp:3173
static void buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias, DenseI64ArrayAttr outpad, DenseI64ArrayAttr stride, TypeAttr accType)
Handles tosa.transpose_conv2d which has outpad and output shape attributes.
Definition TosaOps.cpp:1429
static void extractAdaptivePoolingConstShapeOperands(T op, AdaptivePoolingConstShapeValues &values)
Definition TosaOps.cpp:1298
static LogicalResult verifyConvOpErrorIf(T op)
Definition TosaOps.cpp:896
static FailureOr< int64_t > getZeroPoint(Value val, bool signExtend)
Definition TosaOps.cpp:3104
static constexpr bool IsSupportedAdaptivePoolConstShapeVerifyOp
Definition TosaOps.cpp:1291
LogicalResult tryUpdateDimOrFailure(Operation *op, int64_t &currDim, const int64_t newDim, const StringRef operandName, const StringRef dimName)
Definition TosaOps.cpp:638
static LogicalResult verifyConvOpModes(T op)
Definition TosaOps.cpp:847
static LogicalResult NAryInferReturnTypes(const ValueShapeRange &operands, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
Definition TosaOps.cpp:4041
static Type getStorageElementTypeOrSelf(Type type)
Definition TosaOps.cpp:584
#define COMPATIBLE_RETURN_TYPES(OP)
Definition TosaOps.cpp:3951
static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands, SmallVector< int64_t > &outShape)
Definition TosaOps.cpp:1669
static LogicalResult verifyOutputShapeCompatibleWithExpected(Operation *op, ShapedType outputType, ArrayRef< int64_t > expectedShape, StringRef outputName="output")
Definition TosaOps.cpp:666
static void buildNegateOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input)
This builder is called on single-parameter negate operator to construct input and output zero points ...
Definition TosaOps.cpp:1571
static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias, DenseI64ArrayAttr pad, DenseI64ArrayAttr stride, DenseI64ArrayAttr dilation, TypeAttr accType)
This builder is called on all convolution operators except TransposeConv, which has specialized outpu...
Definition TosaOps.cpp:1405
static void buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, DenseArrayAttr kernel, DenseArrayAttr stride, DenseArrayAttr pad, TypeAttr accType)
Both the tosa.avg_pool2d and unary ops use the same UnaryOpQuantizationAttr but avg_pool operator has...
Definition TosaOps.cpp:1480
static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, Type type1, StringRef name1, Type type2, StringRef name2)
Definition TosaOps.cpp:999
static FailureOr< int64_t > resolveBroadcastDim(const int64_t dim1, const int64_t dim2)
Definition TosaOps.cpp:1655
static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp, const std::string &operand)
Definition TosaOps.cpp:3131
static LogicalResult verifyPoolingOp(T op)
Definition TosaOps.cpp:1234
static LogicalResult verifyDimIsPowerOfTwo(Operation *op, const int64_t dimSize, const llvm::StringRef dimName)
Definition TosaOps.cpp:1755
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
static void updateIfDynamic(int64_t &current, int64_t candidate)
Definition TosaOps.cpp:4189
void inferWeightShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &weightSpatial)
Definition TosaOps.cpp:4282
LogicalResult getSpatialParameters(SmallVector< int64_t > &padValues, SmallVector< int64_t > &strideValues, SmallVector< int64_t > &dilationValues)
Definition TosaOps.cpp:4311
void inferInputShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &inputSpatial)
Definition TosaOps.cpp:4256
ConvInferShapeAdaptor(Conv2DBlockScaledOp::Adaptor adaptor)
Definition TosaOps.cpp:4253
void inferInputShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &inputSpatial)
Definition TosaOps.cpp:4202
void inferWeightShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &weightSpatial)
Definition TosaOps.cpp:4217
ConvInferShapeAdaptor(Conv2DOp::Adaptor adaptor)
Definition TosaOps.cpp:4199
LogicalResult getSpatialParameters(SmallVector< int64_t > &padValues, SmallVector< int64_t > &strideValues, SmallVector< int64_t > &dilationValues)
Definition TosaOps.cpp:4235
void inferWeightShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &weightSpatial)
Definition TosaOps.cpp:4352
ConvInferShapeAdaptor(Conv3DOp::Adaptor adaptor)
Definition TosaOps.cpp:4332
void inferInputShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &inputSpatial)
Definition TosaOps.cpp:4335
LogicalResult getSpatialParameters(SmallVector< int64_t > &padValues, SmallVector< int64_t > &strideValues, SmallVector< int64_t > &dilationValues)
Definition TosaOps.cpp:4372
virtual ParseResult parseOptionalRBrace()=0
Parse a } token if present.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
virtual void printAttribute(Attribute attr)
void printArrowTypeList(TypeRange &&types)
Attributes are known-constant values of operations.
Definition Attributes.h:25
MutableArrayRef< BlockArgument > BlockArgListType
Definition Block.h:95
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:233
FloatAttr getFloatAttr(Type type, double value)
Definition Builders.cpp:259
IntegerType getI32Type()
Definition Builders.cpp:67
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:267
DenseIntElementsAttr getIndexTensorAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:197
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
An attribute that represents a reference to a dense integer vector or tensor object.
virtual InFlightDiagnostic emitError(const Twine &msg={}) const =0
Emit an error to the reader.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition Builders.h:632
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
Definition Builders.h:209
This class indicates that op operates on tosa shape types.
Definition TosaOps.h:82
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
ResultRange result_range
Support result iteration.
Definition Operation.h:435
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:774
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:537
OperandRange operand_range
Definition Operation.h:396
operand_type_range getOperandTypes()
Definition Operation.h:422
result_type_range getResultTypes()
Definition Operation.h:453
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:403
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
Type-safe wrapper around a void* for passing properties, including the properties structs of operatio...
This class provides an abstraction over the different types of ranges over Regions.
Definition Region.h:357
bool empty()
Definition Region.h:60
Adaptor class to abstract the differences between whether value is from a ShapedType or ShapedTypeCom...
bool isDynamicDim(int index) const
Returns whether the index'th dimension is dynamic.
int64_t getDimSize(int index) const
Returns the size of the index'th dimension.
int64_t getRank() const
Returns the rank of the shape.
bool hasStaticShape() const
Returns whether the shape is fully static.
int64_t getNumElements() const
Returns the number of elements in the shape.
void getDims(SmallVectorImpl< int64_t > &res) const
Populates the dimensions from shape referenced.
bool hasRank() const
Returns whether the shape has a rank.
ShapedTypeComponents that represents the components of a ShapedType.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition SymbolTable.h:24
Operation * lookup(StringRef name) const
Look up a symbol with the specified name, returning null if no such name exists.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
Definition Types.cpp:66
bool isF32() const
Definition Types.cpp:40
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition Types.cpp:90
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
bool isF16() const
Definition Types.cpp:38
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
bool isBF16() const
Definition Types.cpp:37
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
type_range getTypes() const
Range of values and shapes (corresponding effectively to Shapes dialect's ValueShape type concept).
ShapeAdaptor getShape(int index) const
Returns the shape of index'th operand.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands)
LogicalResult verifyTosaShapeOperatorWithSameRanks(Operation *op)
Definition TosaOps.cpp:5699
LogicalResult verifyTosaResolvableShapeOperands(Operation *op)
Definition TosaOps.cpp:5686
bool getBroadcastedShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2, SmallVectorImpl< int64_t > &resultShape)
Returns true and sets resultShape to the broadcasted shape from the two given shapes if they are broa...
Definition Traits.cpp:59
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition Utils.cpp:18
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder, Value input, Value weight)
Method to build ConvOpQuantizationAttr, called from ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilde...
RankedTensorType getVariableType(VariableOp variableOp)
Type buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType, Value input, Value weight)
construct ConvOp output type with correct bitwidth based on input/weight width.
ParseResult parseVariableOpTypeOrInitialValue(OpAsmParser &parser, DenseElementsAttr &varShapeAttr, TypeAttr &typeAttr, Attribute &initialValueAttr)
Definition TosaOps.cpp:228
PadOpQuantizationAttr buildPadOpQuantizationAttr(OpBuilder &builder, Value input)
Builds PadOpQuantizationAttr, called from PadOpQuantInfoBuilder: inputZp: input zeropoint.
constexpr int64_t kInferableDimSize
Represents a dimension in the shape of a tensor that can be inferred based on the other provided dime...
Definition TosaOps.h:106
std::pair< Value, Value > createZPsAsConst(OpBuilder &builder, Value input, Value weight)
void printVariableOpTypeOrInitialValue(OpAsmPrinter &p, Operation *op, DenseElementsAttr varShapeAttr, TypeAttr typeAttr, Attribute initialValueAttr)
Definition TosaOps.cpp:253
FailureOr< T > getConstantScalarIntValue(Value val)
Value getTosaConstShape(ImplicitLocOpBuilder &builder, llvm::ArrayRef< int64_t > shape)
MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(OpBuilder &builder, Value a, Value b)
Builds MatMulOpQuantizationAttr, called from MatMulOpQuantInfoBuilder: aZp: input a zeropoint bZp: in...
unsigned getBitWidth(Type type)
Definition TosaOps.cpp:630
std::optional< Value > createZeroPointTensor(OpBuilder &builder, Location loc, Type srcElemType, int64_t zp=0)
Definition TosaOps.cpp:5650
bool isa_tosa_shape_type(mlir::Type t)
Definition TosaOps.cpp:5674
SmallVector< int64_t > convertFromMlirShape(ArrayRef< int64_t > shape)
UnaryOpQuantizationAttr buildUnaryOpQuantizationAttr(OpBuilder &builder, Value input, Type outputRawType)
Builds UnaryOpQuantizationAttr UnaryOpQuantInfoBuilder: inputZp: input zeropoint outputZp: output zer...
Type getStorageElementTypeFromQuantized(quant::QuantizedType quantizedType)
Value createPadConstTensor(OpBuilder &builder, Location loc, Value src, int32_t val=0)
Definition TosaOps.cpp:615
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
LogicalResult verifyCompatibleDims(ArrayRef< int64_t > dims)
Dimensions are compatible if all non-dynamic dims are equal.
LogicalResult verifyRanksMatch(Operation *op, ShapedType lhs, ShapedType rhs, StringRef lhsName, StringRef rhsName)
Verify that two shaped types have matching ranks.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
static ValueKnowledge meet(const ValueKnowledge &lhs, const ValueKnowledge &rhs)
Definition ShapeUtils.h:136
static ValueKnowledge getKnowledgeFromType(Type type)
Definition ShapeUtils.h:45